mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-15 02:08:28 +08:00
Compare commits
35 Commits
qlib_monit
...
high-freq-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56edc16089 | ||
|
|
2b8462d137 | ||
|
|
1979cac50a | ||
|
|
424a48d0fb | ||
|
|
202bbea272 | ||
|
|
6a22136366 | ||
|
|
603c282415 | ||
|
|
22abe852f7 | ||
|
|
e3f463010b | ||
|
|
80aa08215f | ||
|
|
b3893067f7 | ||
|
|
e6dfccce2f | ||
|
|
f9c30f9834 | ||
|
|
f164bf8411 | ||
|
|
1f28044d84 | ||
|
|
3cf0d27a07 | ||
|
|
bcae4bb22e | ||
|
|
f680a564a0 | ||
|
|
9cd41e5a81 | ||
|
|
e23022e9d8 | ||
|
|
ebbbec2a6c | ||
|
|
13d39e6bbc | ||
|
|
b96aab6bef | ||
|
|
700eef4164 | ||
|
|
31c7d72485 | ||
|
|
30ad1967a2 | ||
|
|
0c6cad1d7b | ||
|
|
a0f22571de | ||
|
|
6835b2f67e | ||
|
|
7c4971e566 | ||
|
|
70a9d42c7d | ||
|
|
bcadf47f32 | ||
|
|
4dc14a2489 | ||
|
|
a03b08bb4c | ||
|
|
98086e4fdc |
@@ -1,12 +0,0 @@
|
||||
version = 1
|
||||
|
||||
test_patterns = ["tests/test_*.py"]
|
||||
|
||||
exclude_patterns = ["examples/**"]
|
||||
|
||||
[[analyzers]]
|
||||
name = "python"
|
||||
enabled = true
|
||||
|
||||
[analyzers.meta]
|
||||
runtime_version = "3.x.x"
|
||||
62
.github/stale.yml
vendored
Normal file
62
.github/stale.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
# Configuration for probot-stale - https://github.com/probot/stale
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 60
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
|
||||
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
|
||||
daysUntilClose: 7
|
||||
|
||||
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
|
||||
onlyLabels: []
|
||||
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
exemptLabels:
|
||||
- bug
|
||||
- pinned
|
||||
- security
|
||||
- "[Status] Maybe Later"
|
||||
|
||||
# Set to true to ignore issues in a project (defaults to false)
|
||||
exemptProjects: false
|
||||
|
||||
# Set to true to ignore issues in a milestone (defaults to false)
|
||||
exemptMilestones: false
|
||||
|
||||
# Set to true to ignore issues with an assignee (defaults to false)
|
||||
exemptAssignees: false
|
||||
|
||||
# Label to use when marking as stale
|
||||
staleLabel: wontfix
|
||||
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you
|
||||
for your contributions.
|
||||
|
||||
# Comment to post when removing the stale label.
|
||||
# unmarkComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Comment to post when closing a stale Issue or Pull Request.
|
||||
# closeComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Limit the number of actions per hour, from 1-30. Default is 30
|
||||
limitPerRun: 30
|
||||
|
||||
# Limit to only `issues` or `pulls`
|
||||
# only: issues
|
||||
|
||||
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
|
||||
# pulls:
|
||||
# daysUntilStale: 30
|
||||
# markComment: >
|
||||
# This pull request has been automatically marked as stale because it has not had
|
||||
# recent activity. It will be closed if no further activity occurs. Thank you
|
||||
# for your contributions.
|
||||
|
||||
# issues:
|
||||
# exemptLabels:
|
||||
# - confirmed
|
||||
24
.github/workflows/stale.yml
vendored
24
.github/workflows/stale.yml
vendored
@@ -1,24 +0,0 @@
|
||||
name: Mark stale issues and pull requests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0/3 * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v3
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue otherwise this will be closed in 5 days'
|
||||
stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR otherwise this will be closed in 5 days'
|
||||
stale-issue-label: 'stale'
|
||||
stale-pr-label: 'stale'
|
||||
days-before-stale: 90
|
||||
days-before-close: 5
|
||||
operations-per-run: 100
|
||||
exempt-issue-labels: 'bug,enhancement'
|
||||
remove-stale-when-updated: true
|
||||
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -39,11 +39,9 @@ jobs:
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe -m pip install numpy==1.19.5
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml numpy --user
|
||||
$CONDA\\python.exe -m pip install pyqlib --ignore-installed ruamel.yaml --user
|
||||
else
|
||||
sudo $CONDA/bin/python -m pip install numpy==1.19.5
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml numpy
|
||||
sudo $CONDA/bin/python -m pip install pyqlib --ignore-installed ruamel.yaml
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,7 +34,3 @@ tags
|
||||
|
||||
.pytest_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
|
||||
73
README.md
73
README.md
@@ -7,20 +7,6 @@
|
||||
[](LICENSE)
|
||||
[](https://gitter.im/Microsoft/qlib?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||
|
||||
## :newspaper: **What's NEW!** :sparkling_heart:
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Online serving and automatic model rolling | :star: [Released](https://github.com/microsoft/qlib/pull/290) on May 17, 2021 |
|
||||
| DoubleEnsemble Model | [Released](https://github.com/microsoft/qlib/pull/286) on Mar 2, 2021 |
|
||||
| High-frequency data processing example | [Released](https://github.com/microsoft/qlib/pull/257) on Feb 5, 2021 |
|
||||
| High-frequency trading example | [Part of code released](https://github.com/microsoft/qlib/pull/227) on Jan 28, 2021 |
|
||||
| High-frequency data(1min) | [Released](https://github.com/microsoft/qlib/pull/221) on Jan 27, 2021 |
|
||||
| Tabnet Model | [Released](https://github.com/microsoft/qlib/pull/205) on Jan 22, 2021 |
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
@@ -31,11 +17,10 @@ Qlib is an AI-oriented quantitative investment platform, which aims to realize t
|
||||
|
||||
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
|
||||
|
||||
With Qlib, users can easily try ideas to create better Quant investment strategies.
|
||||
With Qlib, user can easily try ideas to create better Quant investment strategies.
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
- [**Plans**](#plans)
|
||||
- [Framework of Qlib](#framework-of-qlib)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Installation](#installation)
|
||||
@@ -46,24 +31,14 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
- [High-frequency execution](#high-frequency-execution)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contact Us](#contact-us)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
|
||||
# Plans
|
||||
New features under development(order by estimated release time).
|
||||
Your feedbacks about the features are very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
|
||||
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
|
||||
| Meta-Learning-based data selection | Initial opensource version under development |
|
||||
|
||||
# Framework of Qlib
|
||||
|
||||
@@ -72,11 +47,11 @@ Your feedbacks about the features are very important.
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
@@ -144,20 +119,14 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
## Data Preparation
|
||||
Load and prepare data by running the following code:
|
||||
```bash
|
||||
# get 1d data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# get 1min data
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
|
||||
|
||||
```
|
||||
|
||||
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
|
||||
the same repository.
|
||||
Users could create the same dataset with it.
|
||||
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
@@ -245,10 +214,9 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
- Rank Label
|
||||

|
||||
-->
|
||||
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
@@ -265,7 +233,6 @@ Here is a list of models built on `Qlib`.
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -275,10 +242,10 @@ The performance of each model on the `Alpha158` and `Alpha360` dataset can be fo
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
|
||||
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
@@ -304,6 +271,14 @@ Dataset plays a very important role in Quant. Here is a list of the datasets bui
|
||||
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
# High-Frequency Execution
|
||||
High-frequency order execution is a fundamental problem in quantitative finance.
|
||||
It aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument.
|
||||
AI has the potential to mine patterns from a huge mass of high-frequency market data and helps traders make better decisions during order execution.
|
||||
Here is a list of solutions built on `Qlib`.
|
||||
- [Universal Trading for Order Execution with Oracle Policy Distillation](examples/trade/)
|
||||
|
||||
|
||||
# More About Qlib
|
||||
The detailed documents are organized in [docs](docs/).
|
||||
[Sphinx](http://www.sphinx-doc.org) and the readthedocs theme is required to build the documentation in html formats.
|
||||
@@ -341,27 +316,17 @@ which creates a dataset (14 features/factors) from the basic OHLCV daily data of
|
||||
* `+(-)E` indicates with (out) `ExpressionCache`
|
||||
* `+(-)D` indicates with (out) `DatasetCache`
|
||||
|
||||
Most general-purpose databases take too much time to load data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Most general-purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
# Contact Us
|
||||
- If you have any issues, please create issue [here](https://github.com/microsoft/qlib/issues/new/choose) or send messages in [gitter](https://gitter.im/Microsoft/qlib).
|
||||
- If you want to make contributions to `Qlib`, please [create pull requests](https://github.com/microsoft/qlib/compare).
|
||||
- For other reasons, you are welcome to contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)).
|
||||
- We are recruiting new members(both FTEs and interns), your resumes are welcome!
|
||||
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
|
||||
# Contributing
|
||||
|
||||
|
||||
@@ -70,31 +70,3 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
|
||||
|
||||
|
||||
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
|
||||
|
||||
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
File "qlib/qlib/__init__.py", line 19, in init
|
||||
from .data.cache import H
|
||||
File "qlib/qlib/data/__init__.py", line 8, in <module>
|
||||
from .data import (
|
||||
File "qlib/qlib/data/data.py", line 20, in <module>
|
||||
from .cache import H
|
||||
File "qlib/qlib/data/cache.py", line 36, in <module>
|
||||
from .ops import Operators
|
||||
File "qlib/qlib/data/ops.py", line 19, in <module>
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with ``PyCharm`` IDE, users can execute the following command in the project root folder to compile Cython files and generate executable files:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
BIN
docs/_static/img/online_serving.png
vendored
BIN
docs/_static/img/online_serving.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 440 KiB |
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 7.2 KiB |
@@ -1,45 +0,0 @@
|
||||
.. _serial:
|
||||
|
||||
=================================
|
||||
Serialization
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
``Qlib`` supports dumping the state of ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc. into a disk and reloading them.
|
||||
|
||||
Serializable Class
|
||||
========================
|
||||
|
||||
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
|
||||
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
|
||||
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
|
||||
|
||||
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
|
||||
|
||||
Example
|
||||
==========================
|
||||
``Qlib``'s serializable class includes ``DataHandler``, ``DataSet``, ``Processor`` and ``Model``, etc., which are subclass of ``qlib.utils.serial.Serializable``.
|
||||
Specifically, ``qlib.data.dataset.DatasetH`` is one of them. Users can serialize ``DatasetH`` as follows.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
##=============dump dataset=============
|
||||
dataset.to_pickle(path="dataset.pkl") # dataset is an instance of qlib.data.dataset.DatasetH
|
||||
|
||||
##=============reload dataset=============
|
||||
with open("dataset.pkl", "rb") as file_dataset:
|
||||
dataset = pickle.load(file_dataset)
|
||||
|
||||
.. note::
|
||||
Only state of ``DatasetH`` should be saved on the disk, such as some `mean` and `variance` used for data normalization, etc.
|
||||
|
||||
After reloading the ``DatasetH``, users need to reinitialize it. It means that users can reset some states of ``DatasetH`` or ``QlibDataHandler`` such as `instruments`, `start_time`, `end_time` and `segments`, etc., and generate new data according to the states (data is not state and should not be saved on the disk).
|
||||
|
||||
A more detailed example is in this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
|
||||
|
||||
API
|
||||
===================
|
||||
Please refer to `Serializable API <../reference/api.html#module-qlib.utils.serial.Serializable>`_.
|
||||
@@ -1,89 +0,0 @@
|
||||
.. _task_management:
|
||||
|
||||
=================================
|
||||
Task Management
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
|
||||
Task Generating
|
||||
===============
|
||||
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
|
||||
The specific task template can be viewed in
|
||||
`Task Section <../component/workflow.html#task-section>`_.
|
||||
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
|
||||
|
||||
Here is the base class of ``TaskGen``:
|
||||
|
||||
.. autoclass:: qlib.workflow.task.gen.TaskGen
|
||||
:members:
|
||||
|
||||
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
|
||||
|
||||
Task Storing
|
||||
===============
|
||||
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.config import C
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
|
||||
"task_db_name" : "rolling_db" # database name
|
||||
}
|
||||
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
:members:
|
||||
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
|
||||
|
||||
Task Training
|
||||
===============
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
|
||||
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
|
||||
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
|
||||
|
||||
.. autofunction:: qlib.workflow.task.manage.run_task
|
||||
|
||||
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
|
||||
|
||||
.. autoclass:: qlib.model.trainer.Trainer
|
||||
:members:
|
||||
|
||||
``Trainer`` will train a list of tasks and return a list of model recorders.
|
||||
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
|
||||
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
|
||||
|
||||
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
|
||||
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
@@ -31,7 +31,7 @@ Qlib Format Data
|
||||
We've specially designed a data structure to manage financial data, please refer to the `File storage design section in Qlib paper <https://arxiv.org/abs/2009.11189>`_ for detailed information.
|
||||
Such data will be stored with filename suffix `.bin` (We'll call them `.bin` file, `.bin` format, or qlib format). `.bin` file is designed for scientific computing on finance data.
|
||||
|
||||
``Qlib`` provides two different off-the-shelf datasets, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
|
||||
``Qlib`` provides two different off-the-shelf dataset, which can be accessed through this `link <https://github.com/microsoft/qlib/blob/main/qlib/contrib/data/handler.py>`_:
|
||||
|
||||
======================== ================= ================
|
||||
Dataset US Market China Market
|
||||
@@ -41,7 +41,6 @@ Alpha360 √ √
|
||||
Alpha158 √ √
|
||||
======================== ================= ================
|
||||
|
||||
Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency dataset example through this `link <https://github.com/microsoft/qlib/tree/main/examples/highfreq>`_.
|
||||
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
@@ -49,19 +48,15 @@ Qlib Format Dataset
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# download 1d
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# download 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
|
||||
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` directory and ``~/.qlib/qlib_data/us_data`` directory respectively.
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
|
||||
|
||||
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
|
||||
|
||||
@@ -72,19 +67,12 @@ Converting CSV Format into Qlib Format
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
|
||||
Here are some example:
|
||||
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
|
||||
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
@@ -152,16 +140,6 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
|
||||
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
|
||||
|
||||
|
||||
Multiple Stock Modes
|
||||
--------------------------------
|
||||
|
||||
@@ -180,7 +158,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- If users use ``Qlib`` in china-stock mode, china-stock data is required. Users can use ``Qlib`` in china-stock mode according to the following steps:
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in china-stock mode
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/csv_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -189,9 +167,9 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
|
||||
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -199,11 +177,6 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
|
||||
|
||||
|
||||
Data API
|
||||
========================
|
||||
|
||||
@@ -240,25 +213,6 @@ Filter
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
filter: &filter
|
||||
filter_type: ExpressionDFilter
|
||||
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
|
||||
filter_start_time: 2010-01-01
|
||||
filter_end_time: 2010-01-07
|
||||
keep: False
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2010-01-01
|
||||
end_time: 2021-01-22
|
||||
fit_start_time: 2010-01-01
|
||||
fit_end_time: 2015-12-31
|
||||
instruments: *market
|
||||
filter_pipe: [*filter]
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
Reference
|
||||
@@ -320,10 +274,9 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
|
||||
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
|
||||
|
||||
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
|
||||
|
||||
Processor
|
||||
@@ -360,6 +313,7 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
|
||||
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
import qlib
|
||||
@@ -386,9 +340,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
|
||||
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
@@ -413,7 +364,8 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
|
||||
|
||||
|
||||
|
||||
Cache
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
.. _online:
|
||||
|
||||
=================================
|
||||
Online Serving
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
.. image:: ../_static/img/online_serving.png
|
||||
:align: center
|
||||
|
||||
|
||||
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
|
||||
``Online Serving`` is a set of modules for online models using the latest data,
|
||||
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
|
||||
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
|
||||
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
@@ -34,7 +34,6 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
@@ -95,52 +94,6 @@ The ``RecordTemp`` class is a class that enables generate experiment results suc
|
||||
|
||||
- ``SignalRecord``: This class generates the `prediction` results of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
|
||||
|
||||
Here is a simple example of what is done in ``SigAnaRecord``, which users can refer to if they want to calculate IC, Rank IC, Long-Short Return with their own prediction and label.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
Here is a simple exampke of what is done in ``PortAnaRecord``, which users can refer to if they want to do backtest based on their own prediction and label.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
# backtest
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"] - report_normal["cost"])
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
|
||||
@@ -101,7 +101,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -111,6 +111,8 @@ Usage & Example
|
||||
pred_score, strategy=strategy, **BACKTEST_CONFIG
|
||||
)
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
|
||||
@@ -90,12 +90,12 @@ Below is a typical config file of ``qrun``.
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
|
||||
|
||||
@@ -42,7 +42,6 @@ Document Structure
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
@@ -50,8 +49,6 @@ Document Structure
|
||||
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -53,34 +53,6 @@ Cache
|
||||
.. autoclass:: qlib.data.cache.DiskDatasetCache
|
||||
:members:
|
||||
|
||||
|
||||
Storage
|
||||
-------------
|
||||
.. autoclass:: qlib.data.storage.storage.BaseStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.CalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.InstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.storage.FeatureStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileStorageMixin
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileCalendarStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileInstrumentStorage
|
||||
:members:
|
||||
|
||||
.. autoclass:: qlib.data.storage.file_storage.FileFeatureStorage
|
||||
:members:
|
||||
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
|
||||
@@ -180,81 +152,4 @@ Recorder
|
||||
Record Template
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
|
||||
Task Management
|
||||
====================
|
||||
|
||||
|
||||
TaskGen
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.gen
|
||||
:members:
|
||||
|
||||
TaskManager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.manage
|
||||
:members:
|
||||
|
||||
Trainer
|
||||
--------------------
|
||||
.. automodule:: qlib.model.trainer
|
||||
:members:
|
||||
|
||||
Collector
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.collect
|
||||
:members:
|
||||
|
||||
Group
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.group
|
||||
:members:
|
||||
|
||||
Ensemble
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.ensemble
|
||||
:members:
|
||||
|
||||
Utils
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.utils
|
||||
:members:
|
||||
|
||||
|
||||
Online Serving
|
||||
====================
|
||||
|
||||
|
||||
Online Manager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
|
||||
Utils
|
||||
====================
|
||||
|
||||
Serializable
|
||||
--------------------
|
||||
|
||||
.. automodule:: qlib.utils.serial.Serializable
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
:members:
|
||||
@@ -75,14 +75,3 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"default_exp_name": "Experiment",
|
||||
}
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, you can initialize qlib below
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `finetune` method (Optional)
|
||||
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- The parameters must include the parameter `dataset`.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# DoubleEnsemble
|
||||
* DoubleEnsemble is an ensemble framework leveraging learning trajectory based sample reweighting and shuffling based feature selection, to solve both the low signal-to-noise ratio and increasing number of features problems. They identify the key samples based on the training dynamics on each sample and elicit key features based on the ablation impact of each feature via shuffling. The model is applicable to a wide range of base models, capable of extracting complex patterns, while mitigating the overfitting and instability issues for financial market prediction.
|
||||
* This code used in Qlib is implemented by ourselves.
|
||||
* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf).
|
||||
@@ -1,3 +0,0 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
@@ -1,90 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,97 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -29,7 +29,7 @@ data_handler_config: &data_handler_config
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
instruments: *market
|
||||
data_loader:
|
||||
class: QlibDataLoader
|
||||
kwargs:
|
||||
config:
|
||||
feature:
|
||||
- ["Resi($close, 15)/$close", "Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)", "Rsquare($close, 5)", "($high-$low)/$open", "Rsquare($close, 10)", "Corr($close, Log($volume+1), 5)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)", "Corr($close, Log($volume+1), 10)", "Rsquare($close, 20)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)", "Corr($close, Log($volume+1), 20)", "(Less($open, $close)-$low)/$open"]
|
||||
- ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
label:
|
||||
- ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
- ["LABEL0"]
|
||||
freq: day
|
||||
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSZScoreNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandlerLP
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -16,8 +16,6 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -27,13 +25,11 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
|
||||
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
Binary file not shown.
@@ -44,7 +44,6 @@ task:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 158
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
@@ -56,7 +55,7 @@ task:
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 360
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,208 +0,0 @@
|
||||
"""
|
||||
This script is the demonstrating the implementation of Metric Extractor and Detector
|
||||
|
||||
NOTE: A lot of details is not considered in this script
|
||||
- Corner case that will raise error( std == 0)
|
||||
|
||||
|
||||
|
||||
The following functions are used to demonstrate the following examples
|
||||
|
||||
|
||||
· Metric Extractor:
|
||||
case 1) Basic statistics on different slices of the DataFrame df:
|
||||
1) The statistics include:
|
||||
· STD, Mean, Skewnes, Kurtosis
|
||||
2) The above statistics can be calculated on the following data slices:
|
||||
· df.groupby(['datetime'])
|
||||
· df.groupby(['datetime', 'industry' ])
|
||||
3) The statistics could be calculated on the time dimension for each instruments and factor(the factor can be represented by experssion)
|
||||
· <df implemented by expresion>.groupby(['instrument', 'factor'])
|
||||
case 2) Advanced statistics on different slices of the DataFrame df:
|
||||
1) Auto-correlation:
|
||||
· Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
|
||||
2) Correlation between factors:
|
||||
· For any pair of factors (i, j): calculate corr(df.loc[t, :, i], df.loc[t, :, j]). The result is a correlation matrix with each element corresponds to a correlation value between a pair of factors.
|
||||
|
||||
· Detector: detect the abnormality of the extracted metric;
|
||||
a) Algorithms:
|
||||
§ Basic checks: NaN.
|
||||
§ Point anomaly detection.
|
||||
§ Segment anomaly detection.
|
||||
b) Scenarios:
|
||||
§ Online anomaly detection: monitoring streaming data.
|
||||
The usage of the detectors are demonstrated in the `case_1_*`and `case_2_*`
|
||||
|
||||
|
||||
case 3): Examples to use MetricExt to monitor IC and rank IC
|
||||
1) IC(Information Coefficient) #case_3_1
|
||||
2) RankIC #case_3_2
|
||||
"""
|
||||
|
||||
# AUTO download data
|
||||
from typing import List, Union
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.config import REG_CN
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.data.dataset.loader import QlibDataLoader
|
||||
from qlib.data.monitor.metric import format_conv
|
||||
from qlib.data.monitor.metric import MeanM, SkewM, KurtM, StdM, AutoCM, CorrM
|
||||
from qlib.data.monitor.detector import NDDetector, SWNDD, ThresholdD
|
||||
from qlib.data import D
|
||||
import fire
|
||||
|
||||
UNIVERSE = "csi300"
|
||||
START_TIME = "20200101"
|
||||
|
||||
# ------------------ a helper function to get data to demonstrate the functionality --------------------
|
||||
|
||||
|
||||
def get_data_df(col_idx: Union[int, List[int]] = 0, verbose: bool = True):
|
||||
"""
|
||||
a helper function to get data to demonstrate the functionality.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
col_idx : Union[int, List[int]]
|
||||
column index of the metrics
|
||||
"""
|
||||
dh = Alpha158(instruments=UNIVERSE, infer_processors=[], learn_processors=[], start_time=START_TIME)
|
||||
df = dh.fetch()
|
||||
|
||||
if verbose:
|
||||
print(df.head())
|
||||
|
||||
# We don't have industries in dataframe, we generate the with fake data
|
||||
industry = pd.Series(df.index.get_level_values("instrument").str.slice(stop=2).to_list(), index=df.index)
|
||||
|
||||
# select a factor
|
||||
factor_df = format_conv(df.iloc[:, col_idx], industry=industry)
|
||||
if verbose:
|
||||
print(f"Selected metric: {df.columns[col_idx]}")
|
||||
print(factor_df)
|
||||
return factor_df
|
||||
|
||||
|
||||
def get_target(horizon=5):
|
||||
target = f"Ref($close, -{horizon + 1})/Ref($close, -1) - 1" # There are lots of targets: return is one of them
|
||||
qdl = QlibDataLoader(config=([target], ["target"]))
|
||||
df = qdl.load(instruments=UNIVERSE, start_time=START_TIME) # Aligning with factor will improve performance
|
||||
df = format_conv(df["target"])
|
||||
return df
|
||||
|
||||
|
||||
# ----------------- Cases to demonstrate the usage of detector and examples ----------------------
|
||||
|
||||
|
||||
def case_1_1():
|
||||
factor_df = get_data_df()
|
||||
# 1) Extract metrics
|
||||
|
||||
# 1.1) df.groupby(["datetime"])
|
||||
mtrc = MeanM()
|
||||
m_mean = mtrc.extract(factor_df)
|
||||
print(m_mean)
|
||||
|
||||
ndd = NDDetector()
|
||||
ndd.fit(m_mean) # use historical data to fit detector
|
||||
check_res = ndd.check(m_mean)
|
||||
print(check_res) # detecting on new data or historical data
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_1_2():
|
||||
factor_df = get_data_df()
|
||||
# 1.2) df.groupby("datetime", "industry")
|
||||
mtrc = MeanM(group=["industry"])
|
||||
m_multi = mtrc.extract(factor_df)
|
||||
print(m_multi)
|
||||
|
||||
for col_name, s in m_multi.iteritems():
|
||||
print(col_name)
|
||||
ndd = NDDetector()
|
||||
ndd.fit(s) # use historical data to fit detector
|
||||
check_res = ndd.check(s)
|
||||
print(check_res) # detecting on new data or historical data
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_1_3():
|
||||
# case 1.3
|
||||
# factor_df = get_data_df()
|
||||
qdl = QlibDataLoader(config=(["$close/Ref($close, 1) - 1"], ["return"]))
|
||||
df = qdl.load(instruments=["SH600519"], start_time=START_TIME)
|
||||
df = format_conv(df)
|
||||
s = df.iloc[:, 0]
|
||||
print(s)
|
||||
dtc = SWNDD(window=20)
|
||||
dtc.fit(s) # fit use historical data (TODO: updating will be supported in the future)
|
||||
check_res = dtc.check(s) #
|
||||
print(check_res)
|
||||
print(check_res.value_counts())
|
||||
print(check_res[check_res])
|
||||
|
||||
|
||||
def case_2_1():
|
||||
# · Calculate corr(df.loc[t, :, :], df.loc[t-w, :, :]), w=1, 2, ….
|
||||
factor_df = get_data_df()
|
||||
acm = AutoCM()
|
||||
mtrc = acm.extract(factor_df)
|
||||
print(mtrc)
|
||||
|
||||
thd = ThresholdD(0.0, reverse=True)
|
||||
check_res = thd.check(mtrc)
|
||||
|
||||
print(check_res)
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_2_2():
|
||||
factor_df1, factor_df2 = get_data_df(0), get_data_df(1)
|
||||
|
||||
cm = CorrM()
|
||||
mtrc = cm.extract(factor_df1, factor_df2)
|
||||
print(mtrc)
|
||||
|
||||
thd = ThresholdD(0.0, reverse=True)
|
||||
check_res = thd.check(mtrc)
|
||||
|
||||
print(check_res)
|
||||
print(check_res.value_counts())
|
||||
|
||||
|
||||
def case_3_1_3_2():
|
||||
target, factor = get_target(), get_data_df(0)
|
||||
ic_m, rank_ic_m = CorrM(), CorrM(mode="spearman")
|
||||
ic, rank_ic = ic_m.extract(factor, target), rank_ic_m.extract(factor, target)
|
||||
print(pd.DataFrame({"ic": ic, "rank_ic": rank_ic}))
|
||||
|
||||
|
||||
def run(test_list=["case_1_1", "case_1_2", "case_1_3", "case_2_1", "case_2_2", "case_3_1_3_2"]):
|
||||
"""
|
||||
run the specific tests
|
||||
|
||||
python monitor.py case_3_1_3_2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
test_list : str[]
|
||||
The tests to run
|
||||
"""
|
||||
if isinstance(test_list, str):
|
||||
test_list = [test_list]
|
||||
for fn in test_list:
|
||||
globals()[fn]()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
qlib.init()
|
||||
fire.Fire(run)
|
||||
@@ -1,130 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e62a81e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from tqdm.auto import tqdm\n",
|
||||
"%matplotlib inline\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c503217b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.data.monitor.analyser import Analyser\n",
|
||||
"import qlib\n",
|
||||
"qlib.init()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9c276470",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SimpleDFA(Analyser):\n",
|
||||
" \"\"\"Simple (D)ata(F)rame (A)nalyser\"\"\"\n",
|
||||
" def analyse(self, data: pd.DataFrame, *args, **kwargs):\n",
|
||||
" data.plot(*args, **kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "110262e4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from monitor import get_data_df, AutoCM"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0ea38c62",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# get data\n",
|
||||
"factor_df = get_data_df([1], verbose=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dbded6fe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# metric extractor\n",
|
||||
"acm = AutoCM()\n",
|
||||
"mtrc = acm.extract(factor_df)\n",
|
||||
"print(mtrc)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "65517c81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Analyser\n",
|
||||
"sa = SimpleDFA()\n",
|
||||
"sa.analyse(mtrc, title='Auto Correlation')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dab6fb2e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -25,11 +25,4 @@ The example is given in `workflow.py`, users can run the code as follows.
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
```
|
||||
@@ -1,13 +1,24 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
import fire
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.utils import init_instance_by_config, exists_qlib_data
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.ops import Operators
|
||||
from qlib.data.data import Cal
|
||||
@@ -16,16 +27,17 @@ from qlib.tests.data import GetData
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow:
|
||||
class HighfreqWorkflow(object):
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-15 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
train_end_time = "2020-11-30 16:00:00"
|
||||
test_start_time = "2020-12-01 00:00:00"
|
||||
start_time = pd.Timestamp("2020-09-15 00:00:00")
|
||||
end_time = pd.Timestamp("2021-01-18 16:00:00")
|
||||
train_end_time = pd.Timestamp("2020-11-30 16:00:00")
|
||||
test_start_time = pd.Timestamp("2020-12-01 00:00:00")
|
||||
|
||||
DATA_HANDLER_CONFIG0 = {
|
||||
"start_time": start_time,
|
||||
@@ -85,7 +97,9 @@ class HighfreqWorkflow:
|
||||
# use yahoo_cn_1min data
|
||||
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
|
||||
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
|
||||
qlib.init(**QLIB_INIT_CONFIG)
|
||||
|
||||
def _prepare_calender_cache(self):
|
||||
@@ -132,44 +146,72 @@ class HighfreqWorkflow:
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
dataset.init(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
},
|
||||
)
|
||||
dataset_backtest.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
xtest = dataset.prepare("test")
|
||||
backtest_test = dataset_backtest.prepare("test")
|
||||
xtest = dataset.prepare(["test"])
|
||||
backtest_test = dataset_backtest.prepare(["test"])
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
|
||||
def get_high_freq_data(self, data_path):
|
||||
self._init_qlib()
|
||||
self._prepare_calender_cache()
|
||||
|
||||
import os
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
xtrain, xtest = dataset.prepare(["train", "test"])
|
||||
normed_feature = pd.concat([xtrain, xtest]).sort_index()
|
||||
dic = dict(tuple(normed_feature.groupby("instrument")))
|
||||
feature_path = os.path.join(data_path, "normed_feature/")
|
||||
if not os.path.exists(feature_path):
|
||||
os.makedirs(feature_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(feature_path + f"{k}.pkl")
|
||||
|
||||
|
||||
dataset_backtest = init_instance_by_config(self.task["dataset_backtest"])
|
||||
backtest_train, backtest_test = dataset_backtest.prepare(["train", "test"])
|
||||
backtest = pd.concat([backtest_train, backtest_test]).sort_index()
|
||||
backtest['date'] = backtest.index.map(lambda x: x[1].date())
|
||||
backtest.set_index('date', append=True, drop=True, inplace=True)
|
||||
dic = dict(tuple(backtest.groupby("instrument")))
|
||||
backtest_path = os.path.join(data_path, "backtest/")
|
||||
if not os.path.exists(backtest_path):
|
||||
os.makedirs(backtest_path)
|
||||
for k, v in dic.items():
|
||||
v.to_pickle(backtest_path + f"{k}.pkl.backtest")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(HighfreqWorkflow)
|
||||
#fire.Fire(HighfreqWorkflow)
|
||||
data_path = '../data/'
|
||||
workflow = HighfreqWorkflow()
|
||||
workflow.get_high_freq_data(data_path)
|
||||
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
market: &market 'csi300'
|
||||
start_time: &start_time "2020-09-15 00:00:00"
|
||||
end_time: &end_time "2021-01-18 16:00:00"
|
||||
train_end_time: &train_end_time "2020-11-15 16:00:00"
|
||||
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
|
||||
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
|
||||
test_start_time: &test_start_time "2020-12-01 00:00:00"
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: *start_time
|
||||
end_time: *end_time
|
||||
fit_start_time: *start_time
|
||||
fit_end_time: *train_end_time
|
||||
instruments: *market
|
||||
freq: '1min'
|
||||
infer_processors:
|
||||
- class: 'RobustZScoreNorm'
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
clip_outlier: false
|
||||
- class: "Fillna"
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
learn_processors:
|
||||
- class: 'DropnaLabel'
|
||||
- class: 'CSRankNorm'
|
||||
kwargs:
|
||||
fields_group: 'label'
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
task:
|
||||
model:
|
||||
class: "HFLGBModel"
|
||||
module_path: "qlib.contrib.model.highfreq_gdbt_model"
|
||||
kwargs:
|
||||
objective: 'binary'
|
||||
metric: ['binary_logloss','auc']
|
||||
verbosity: -1
|
||||
learning_rate: 0.01
|
||||
max_depth: 8
|
||||
num_leaves: 150
|
||||
lambda_l1: 1.5
|
||||
lambda_l2: 1
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: "DatasetH"
|
||||
module_path: "qlib.data.dataset"
|
||||
kwargs:
|
||||
handler:
|
||||
class: "Alpha158"
|
||||
module_path: "qlib.contrib.data.handler"
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [*start_time, *train_end_time]
|
||||
valid: [*train_end_time, *valid_end_time]
|
||||
test: [*test_start_time, *end_time]
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
@@ -1,23 +0,0 @@
|
||||
# LightGBM hyperparameter
|
||||
|
||||
## Alpha158
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_158.py
|
||||
```
|
||||
|
||||
## Alpha360
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_360.py
|
||||
```
|
||||
@@ -1,46 +0,0 @@
|
||||
import qlib
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.config import CSI300_DATASET_CONFIG
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
def objective(trial):
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
|
||||
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
|
||||
"subsample": trial.suggest_uniform("subsample", 0, 1),
|
||||
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
|
||||
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
|
||||
"max_depth": 10,
|
||||
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
|
||||
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
|
||||
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
|
||||
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
|
||||
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
|
||||
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
|
||||
},
|
||||
},
|
||||
}
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region="cn")
|
||||
|
||||
dataset = init_instance_by_config(CSI300_DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
@@ -1,49 +0,0 @@
|
||||
import qlib
|
||||
import optuna
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS
|
||||
|
||||
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)
|
||||
|
||||
|
||||
def objective(trial):
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
|
||||
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
|
||||
"subsample": trial.suggest_uniform("subsample", 0, 1),
|
||||
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
|
||||
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
|
||||
"max_depth": 10,
|
||||
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
|
||||
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
|
||||
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
|
||||
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
|
||||
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
|
||||
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
evals_result = dict()
|
||||
model = init_instance_by_config(task["model"])
|
||||
model.fit(dataset, evals_result=evals_result)
|
||||
return min(evals_result["valid"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
dataset = init_instance_by_config(DATASET_CONFIG)
|
||||
|
||||
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
|
||||
study.optimize(objective, n_jobs=6)
|
||||
@@ -1,5 +0,0 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
optuna==2.7.0
|
||||
optuna-dashboard==0.4.1
|
||||
@@ -1,32 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
# get model feature importance
|
||||
feature_importance = model.get_feature_importance()
|
||||
print("feature importance:")
|
||||
print(feature_importance)
|
||||
@@ -1,105 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingTaskExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region=REG_CN,
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
):
|
||||
# TaskManager config
|
||||
if task_config is None:
|
||||
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def task_generating(self):
|
||||
print("========== task_generating ==========")
|
||||
tasks = task_generator(
|
||||
tasks=self.task_config,
|
||||
generators=self.rolling_gen, # generate different date segments
|
||||
)
|
||||
pprint(tasks)
|
||||
return tasks
|
||||
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def my_filter(recorder):
|
||||
# only choose the results of "LGBModel"
|
||||
model_key, rolling_key = rec_key(recorder)
|
||||
if model_key == "LGBModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
collector = RecorderCollector(
|
||||
experiment=self.experiment_name,
|
||||
process_list=RollingGroup(),
|
||||
rec_key_func=rec_key,
|
||||
rec_filter_func=my_filter,
|
||||
)
|
||||
print(collector())
|
||||
|
||||
def main(self):
|
||||
self.reset()
|
||||
tasks = self.task_generating()
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python task_manager_rolling.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(RollingTaskExample)
|
||||
@@ -1,92 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=None,
|
||||
):
|
||||
"""
|
||||
Init OnlineManagerExample.
|
||||
|
||||
Args:
|
||||
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
|
||||
region (str, optional): the stock region. Defaults to "cn".
|
||||
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
|
||||
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
|
||||
task_db_name (str, optional): database name. Defaults to "rolling_db".
|
||||
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
|
||||
rolling_step (int, optional): the step for rolling. Defaults to 80.
|
||||
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
begin_time=self.start_time,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this to run all workflow automatically
|
||||
def main(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== simulate ==========")
|
||||
self.rolling_online_manager.simulate(end_time=self.end_time)
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
|
||||
fire.Fire(OnlineSimulationExample)
|
||||
@@ -1,130 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineManager works with rolling tasks.
|
||||
There are four parts including first train, routine 1, add strategy and routine 2.
|
||||
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
|
||||
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
|
||||
Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.tasks = tasks
|
||||
self.add_tasks = add_tasks
|
||||
self.rolling_step = rolling_step
|
||||
strategies = []
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def first_run(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== first_run ==========")
|
||||
self.rolling_online_manager.first_train()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def routine(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== routine ==========")
|
||||
self.rolling_online_manager.routine()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def add_strategy(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== add strategy ==========")
|
||||
strategies = []
|
||||
for task in self.add_tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
self.rolling_online_manager.add_strategy(strategies=strategies)
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
self.routine()
|
||||
self.add_strategy()
|
||||
self.routine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
# python rolling_online_management.py first_run
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python rolling_online_management.py routine
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
@@ -1,54 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineTool works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
import copy
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
from qlib.tests.config import CSI300_GBDT_TASK
|
||||
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
}
|
||||
|
||||
|
||||
class UpdatePredExample:
|
||||
def __init__(
|
||||
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
|
||||
):
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
self.experiment_name = experiment_name
|
||||
self.online_tool = OnlineToolR(self.experiment_name)
|
||||
self.task_config = task_config
|
||||
|
||||
def first_train(self):
|
||||
rec = task_train(self.task_config, experiment_name=self.experiment_name)
|
||||
self.online_tool.reset_online_tag(rec) # set to online model
|
||||
|
||||
def update_online_pred(self):
|
||||
self.online_tool.update_online_pred()
|
||||
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.update_online_pred()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to train a model and set it to online model, use the command below
|
||||
# python update_online_pred.py first_train
|
||||
## to update online predictions once a day, use the command below
|
||||
# python update_online_pred.py update_online_pred
|
||||
## to see the whole process with your own parameters, use the command below
|
||||
# python update_online_pred.py main --experiment_name="your_exp_name"
|
||||
fire.Fire(UpdatePredExample)
|
||||
@@ -1,17 +0,0 @@
|
||||
# Rolling Process Data
|
||||
|
||||
This workflow is an example for `Rolling Process Data`.
|
||||
|
||||
## Background
|
||||
|
||||
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
|
||||
|
||||
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
|
||||
|
||||
|
||||
## Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py rolling_process
|
||||
```
|
||||
@@ -1,32 +0,0 @@
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.loader import DataLoaderDH
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class RollingDataHandler(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
data_loader_kwargs={},
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "DataLoaderDH",
|
||||
"kwargs": {**data_loader_kwargs},
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
)
|
||||
@@ -1,137 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class RollingDataWorkflow:
|
||||
|
||||
MARKET = "csi300"
|
||||
start_time = "2010-01-01"
|
||||
end_time = "2019-12-31"
|
||||
rolling_cnt = 5
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
# use yahoo_cn_1min data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
def _dump_pre_handler(self, path):
|
||||
handler_config = {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": {
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"instruments": self.MARKET,
|
||||
"infer_processors": [],
|
||||
"learn_processors": [],
|
||||
},
|
||||
}
|
||||
pre_handler = init_instance_by_config(handler_config)
|
||||
pre_handler.config(dump_all=True)
|
||||
pre_handler.to_pickle(path)
|
||||
|
||||
def _load_pre_handler(self, path):
|
||||
with open(path, "rb") as file_dataset:
|
||||
pre_handler = pickle.load(file_dataset)
|
||||
return pre_handler
|
||||
|
||||
def rolling_process(self):
|
||||
self._init_qlib()
|
||||
self._dump_pre_handler("pre_handler.pkl")
|
||||
pre_handler = self._load_pre_handler("pre_handler.pkl")
|
||||
|
||||
train_start_time = (2010, 1, 1)
|
||||
train_end_time = (2012, 12, 31)
|
||||
valid_start_time = (2013, 1, 1)
|
||||
valid_end_time = (2013, 12, 31)
|
||||
test_start_time = (2014, 1, 1)
|
||||
test_end_time = (2014, 12, 31)
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "RollingDataHandler",
|
||||
"module_path": "rolling_handler",
|
||||
"kwargs": {
|
||||
"start_time": datetime(*train_start_time),
|
||||
"end_time": datetime(*test_end_time),
|
||||
"fit_start_time": datetime(*train_start_time),
|
||||
"fit_end_time": datetime(*train_end_time),
|
||||
"infer_processors": [
|
||||
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
|
||||
],
|
||||
"learn_processors": [
|
||||
{"class": "DropnaLabel"},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
],
|
||||
"data_loader_kwargs": {
|
||||
"handler_config": pre_handler,
|
||||
},
|
||||
},
|
||||
},
|
||||
"segments": {
|
||||
"train": (datetime(*train_start_time), datetime(*train_end_time)),
|
||||
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
|
||||
"test": (datetime(*test_start_time), datetime(*test_end_time)),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dataset = init_instance_by_config(dataset_config)
|
||||
|
||||
for rolling_offset in range(self.rolling_cnt):
|
||||
|
||||
print(f"===========rolling{rolling_offset} start===========")
|
||||
if rolling_offset:
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
"processor_kwargs": {
|
||||
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
},
|
||||
},
|
||||
segments={
|
||||
"train": (
|
||||
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
|
||||
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
|
||||
),
|
||||
"valid": (
|
||||
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
|
||||
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
|
||||
),
|
||||
"test": (
|
||||
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
|
||||
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_FIT_SEQ,
|
||||
}
|
||||
)
|
||||
|
||||
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
|
||||
print(dtrain, dvalid, dtest)
|
||||
## print or dump data
|
||||
print(f"===========rolling{rolling_offset} end===========")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(RollingDataWorkflow)
|
||||
@@ -5,11 +5,13 @@ import os
|
||||
import sys
|
||||
import fire
|
||||
import time
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
@@ -21,7 +23,8 @@ from pprint import pprint
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
# init qlib
|
||||
@@ -36,8 +39,12 @@ exp_manager = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
|
||||
# decorator to check the arguments
|
||||
|
||||
104
examples/trade/README.md
Normal file
104
examples/trade/README.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Universal Trading for Order Execution with Oracle Policy Distillation
|
||||
This is the experiment code for our AAAI 2021 paper "[Universal Trading for Order Execution with Oracle Policy Distillation](https://arxiv.org/abs/2103.10860)", including the implementations of all the compared methods in the paper and a general reinforcement learning framework for order execution in quantitative finance.
|
||||
|
||||
## Abstract
|
||||
As a fundamental problem in algorithmic trading, order execution aims at fulfilling a specific trading order, either liquidation or acquirement, for a given instrument. Towards effective execution strategy, recent years have witnessed the shift from the analytical view with model-based market assumptions to model-free perspective, i.e., reinforcement learning, due to its nature of sequential decision optimization. However, the noisy and yet imperfect market information that can be leveraged by the policy has made it quite challenging to build up sample efficient reinforcement learning methods to achieve effective order execution. In this paper, we propose a novel universal trading policy optimization framework to bridge the gap between the noisy yet imperfect market states and the optimal action sequences for order execution. Particularly, this framework leverages a policy distillation method that can better guide the learning of the common policy towards practically optimal execution by an oracle teacher with perfect information to approximate the optimal trading strategy. The extensive experiments have shown significant improvements of our method over various strong baselines, with reasonable trading actions.
|
||||
|
||||
## Environment Dependencies
|
||||
|
||||
### Dependencies
|
||||
|
||||
```
|
||||
gym==0.17.3
|
||||
torch==1.6.0
|
||||
numba==0.51.2
|
||||
numpy==1.19.1
|
||||
pandas==1.1.3
|
||||
tqdm==4.50.2
|
||||
tianshou==0.3.0.post1
|
||||
env==0.1.0
|
||||
PyYAML==5.4.1
|
||||
redis==3.5.3
|
||||
```
|
||||
|
||||
### Environment Variable
|
||||
|
||||
`EXP_PATH` Absolute path to your config folder, we give folder `exp` as an example.
|
||||
|
||||
`OUTPUT_DIR` Absolute path to your log folder.
|
||||
|
||||
## Data Processing
|
||||
|
||||
For Feature processing, we take Yahoo dataset as an example, which can be precessed in `qlib/examples/highfreq/workflow.py` file. If you have a need to change your data storage path, you can change the `data_path` in `workflow.py`, and then do the following.
|
||||
|
||||
```
|
||||
python workflow.py
|
||||
```
|
||||
|
||||
For order generation, if you have changed change the the `data_path` in `workflow.py`, change `data_path` in `order_gen.py` again, then do the following.
|
||||
|
||||
```
|
||||
python order_gen.py
|
||||
```
|
||||
|
||||
## Training and backtest
|
||||
|
||||
### Config file
|
||||
|
||||
Config file is need to start our project, we take `PPO`, `OPDS` and `OPD` as an example in folder `exp/example`. If you want to use our given config, make sure the `data_path` you set before matches the config file.
|
||||
|
||||
### Baseline method
|
||||
|
||||
To run a method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config={config_path}
|
||||
```
|
||||
|
||||
Where `{config_path}` means the relative path from your config.yml to `EXP_PATH`.
|
||||
|
||||
If you need to run our given method such as PPO method, you can do the following.
|
||||
|
||||
```
|
||||
python main.py --config=example/PPO/config.yml
|
||||
```
|
||||
|
||||
### OPD method
|
||||
|
||||
OPD method is a multi step method, at first you should run OPDT as the teacher in OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT/config.yml
|
||||
```
|
||||
|
||||
After training, find the `policy_best` file in your OPDT log file and copy it to `trade` file for backtest. Also you can change `policy_path` in the `example/OPDT_b/config.yml` to your `policy_best` file. Then run the backtest method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPDT_b/config.yml
|
||||
```
|
||||
|
||||
then processed feature from teacher. Remember to change `log_path` if you have changed `log_dir` in `OPDT_b/config.yml`.
|
||||
|
||||
```
|
||||
python teacher_feature.py
|
||||
```
|
||||
|
||||
and finally start our OPD method.
|
||||
|
||||
```
|
||||
python main.py --config=example/OPD/config.yml
|
||||
```
|
||||
|
||||
## Citation
|
||||
You are more than welcome to citetmu our paper:
|
||||
```
|
||||
@inproceedings{fang2021universal,
|
||||
title={Universal Trading for Order Execution with Oracle Policy Distillation},
|
||||
author={Fang, Yuchen and Ren, Kan and Liu, Weiqing and Zhou, Dong and Zhang, Weinan and Bian, Jiang and Yu, Yong and Liu, Tie-Yan},
|
||||
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
|
||||
volume={35},
|
||||
number={1},
|
||||
pages={107--115},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
10
examples/trade/__init__.py
Normal file
10
examples/trade/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# from rl4execution import env, trainer, exploration
|
||||
|
||||
# __all__ = [
|
||||
# "env",
|
||||
# "data",
|
||||
# "utils",
|
||||
# "policy",
|
||||
# "trainer",
|
||||
# "exploration",
|
||||
# ]
|
||||
4
examples/trade/action/__init__.py
Normal file
4
examples/trade/action/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
from .action_rl import *
|
||||
from .action_rule import *
|
||||
from .action_rl import *
|
||||
27
examples/trade/action/action_rl.py
Normal file
27
examples/trade/action/action_rl.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Static_Action(Base_Action):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
self.action_num = config["action_num"]
|
||||
self.action_map = config["action_map"]
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Discrete(self.action_num)
|
||||
|
||||
def get_action(self, action, target, position, **kargs):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
:param position:
|
||||
:param target:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return min(target * self.action_map[action], position)
|
||||
46
examples/trade/action/action_rule.py
Normal file
46
examples/trade/action/action_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Dynamic(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (max_step_num - (t + 1)) * action
|
||||
|
||||
|
||||
class Rule_Static(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, max_step_num, t, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param max_step_num:
|
||||
:param t: param **kargs:
|
||||
:param target:
|
||||
:param max_step_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / max_step_num * action
|
||||
20
examples/trade/action/base.py
Normal file
20
examples/trade/action/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
|
||||
class Base_Action(object):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
return
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_action(*args, **kargs)
|
||||
|
||||
def get_action(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
return action
|
||||
46
examples/trade/action/interval_rule.py
Normal file
46
examples/trade/action/interval_rule.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
|
||||
from .base import Base_Action
|
||||
|
||||
|
||||
class Rule_Static_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return target / (interval_num) * action
|
||||
|
||||
|
||||
class Rule_Dynamic_Interval(Base_Action):
|
||||
""" """
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return Box(0, np.inf, shape=(), dtype=np.float32)
|
||||
|
||||
def get_action(self, action, target, position, interval_num, interval, **kargs):
|
||||
"""
|
||||
|
||||
:param action: param target:
|
||||
:param position: param interval_num:
|
||||
:param interval: param **kargs:
|
||||
:param target:
|
||||
:param interval_num:
|
||||
:param **kargs:
|
||||
|
||||
"""
|
||||
return position / (interval_num - interval) * action
|
||||
1
examples/trade/agent/__init__.py
Normal file
1
examples/trade/agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .basic import *
|
||||
69
examples/trade/agent/basic.py
Normal file
69
examples/trade/agent/basic.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch
|
||||
import numpy as np
|
||||
import torch
|
||||
from env import nan_weighted_avg
|
||||
|
||||
|
||||
class TWAP(BasePolicy):
|
||||
""" The TWAP strategy. """
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.num_cpus = config["num_cpus"]
|
||||
|
||||
# @njit(parallel=True)
|
||||
def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
|
||||
act = [1] * len(batch.obs.private)
|
||||
return Batch(act=act, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class VWAP(BasePolicy):
|
||||
""" The VWAP strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
r = np.stack(obs.prediction).reshape(-1)
|
||||
return Batch(act=r, state=state)
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
pass
|
||||
|
||||
|
||||
class AC(VWAP):
|
||||
"""Almgren-Chriss strategy."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.T = config["max_step_num"]
|
||||
self.gamma = 0
|
||||
self.tau = 1
|
||||
self.lamb = config["lambda"]
|
||||
self.eps = 0.0625
|
||||
self.alpha = 0.02
|
||||
self.eta = 2.5e-6
|
||||
|
||||
def forward(self, batch, state, **kwargs):
|
||||
obs = batch.obs
|
||||
sig = np.stack(obs.prediction).reshape(-1)
|
||||
sell = ~np.stack(obs.is_buy).astype(np.bool)
|
||||
data = np.stack(obs.private)
|
||||
t = data[:, 2]
|
||||
t = t + 1
|
||||
k_tild = self.lamb / self.eta * sig * sig
|
||||
k = np.arccosh(k_tild / 2 + 1)
|
||||
act = (np.sinh(k * (self.T - t)) - np.sinh(k * (self.T - t - 1))) / np.sinh(k * self.T)
|
||||
return Batch(act=act, state=state)
|
||||
342
examples/trade/collector.py
Normal file
342
examples/trade/collector.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import gym
|
||||
import time
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from vecenv import BaseVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.data.collector import _batch_set_item
|
||||
|
||||
|
||||
class Collector(object):
|
||||
def __init__(
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
env: Union[gym.Env, BaseVectorEnv],
|
||||
testing=False,
|
||||
buffer: Optional[ReplayBuffer] = None,
|
||||
preprocess_fn: Optional[Callable[..., Batch]] = None,
|
||||
action_noise: Optional[BaseNoise] = None,
|
||||
reward_metric: Optional[Callable[[np.ndarray], float]] = np.sum,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not isinstance(env, BaseVectorEnv):
|
||||
env = DummyVectorEnv([lambda: env])
|
||||
self.env = env
|
||||
self.env_num = len(env)
|
||||
# environments that are available in step()
|
||||
# this means all environments in synchronous simulation
|
||||
# but only a subset of environments in asynchronous simulation
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
# self.async is a flag to indicate whether this collector works
|
||||
# with asynchronous simulation
|
||||
self.is_async = env.is_async
|
||||
self.testing = testing
|
||||
# need cache buffers before storing in the main buffer
|
||||
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
|
||||
self.buffer = buffer
|
||||
self.policy = policy
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self.process_fn = policy.process_fn
|
||||
# self._action_space = env.action_space
|
||||
self._action_noise = action_noise
|
||||
self._rew_metric = reward_metric or Collector._default_rew_metric
|
||||
# avoid creating attribute outside __init__
|
||||
# self.reset()
|
||||
|
||||
@staticmethod
|
||||
def _default_rew_metric(x: Union[Number, np.number]) -> Union[Number, np.number]:
|
||||
# this internal function is designed for single-agent RL
|
||||
# for multi-agent RL, a reward_metric must be provided
|
||||
assert np.asanyarray(x).size == 1, "Please specify the reward_metric " "since the reward is not a scalar."
|
||||
return x
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all related variables in the collector."""
|
||||
# use empty Batch for ``state`` so that ``self.data`` supports slicing
|
||||
# convert empty Batch to None when passing data to policy
|
||||
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, obs_next={}, policy={})
|
||||
self.reset_env()
|
||||
self.reset_buffer()
|
||||
self.reset_stat()
|
||||
if self._action_noise is not None:
|
||||
self._action_noise.reset()
|
||||
|
||||
def reset_stat(self) -> None:
|
||||
"""Reset the statistic variables."""
|
||||
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
|
||||
|
||||
def reset_buffer(self) -> None:
|
||||
"""Reset the main data buffer."""
|
||||
if self.buffer is not None:
|
||||
self.buffer.reset()
|
||||
|
||||
def get_env_num(self) -> int:
|
||||
""" """
|
||||
return self.env_num
|
||||
|
||||
def reset_env(self) -> None:
|
||||
"""Reset all of the environment(s)' states and the cache buffers."""
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
self.env.reset_sampler()
|
||||
obs, stop_id = self.env.reset()
|
||||
if self.preprocess_fn:
|
||||
obs = self.preprocess_fn(obs=obs).get("obs", obs)
|
||||
self.data.obs = obs
|
||||
for b in self._cached_buf:
|
||||
b.reset()
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in stop_id])
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
"""Reset the hidden state: self.data.state[id]."""
|
||||
state = self.data.state # it is a reference
|
||||
if isinstance(state, torch.Tensor):
|
||||
state[id].zero_()
|
||||
elif isinstance(state, np.ndarray):
|
||||
state[id] = None if state.dtype == np.object else 0
|
||||
elif isinstance(state, Batch):
|
||||
state.empty_(id)
|
||||
|
||||
def collect(
|
||||
self,
|
||||
n_step: Optional[int] = None,
|
||||
n_episode: Optional[Union[int, List[int]]] = None,
|
||||
random: bool = False,
|
||||
render: Optional[float] = None,
|
||||
log_fn=None,
|
||||
no_grad: bool = True,
|
||||
) -> Dict[str, float]:
|
||||
"""Collect a specified number of step or episode.
|
||||
|
||||
:param int: n_step: how many steps you want to collect.
|
||||
:param n_episode: how many episodes you want to collect. If it is an
|
||||
int, it means to collect at lease ``n_episode`` episodes; if it is
|
||||
a list, it means to collect exactly ``n_episode[i]`` episodes in
|
||||
the i-th environment
|
||||
:param bool: random: whether to use random policy for collecting data,
|
||||
defaults to False.
|
||||
:param float: render: the sleep time between rendering consecutive
|
||||
frames, defaults to None (no rendering).
|
||||
:param bool: no_grad: whether to retain gradient in policy.forward,
|
||||
defaults to True (no gradient retaining).
|
||||
|
||||
.. note::
|
||||
|
||||
One and only one collection number specification is permitted,
|
||||
either ``n_step`` or ``n_episode``.
|
||||
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param log_fn: Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param List[int]]]: (Default value = None)
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:param n_step: Optional[int]: (Default value = None)
|
||||
:param n_episode: Optional[Union[int:
|
||||
:param random: bool: (Default value = False)
|
||||
:param render: Optional[float]: (Default value = None)
|
||||
:param no_grad: bool: (Default value = True)
|
||||
:returns: A dict including the following keys
|
||||
|
||||
* ``n/ep`` the collected number of episodes.
|
||||
* ``n/st`` the collected number of steps.
|
||||
* ``v/st`` the speed of steps per second.
|
||||
* ``v/ep`` the speed of episode per second.
|
||||
* ``rew`` the mean reward over collected episodes.
|
||||
* ``len`` the mean length over collected episodes.
|
||||
|
||||
"""
|
||||
assert (
|
||||
(n_step is not None and n_episode is None and n_step > 0)
|
||||
or (n_step is None and n_episode is not None and np.sum(n_episode) > 0)
|
||||
or self.testing
|
||||
), "Only one of n_step or n_episode is allowed in Collector.collect, "
|
||||
f"got n_step = {n_step}, n_episode = {n_episode}."
|
||||
start_time = time.time()
|
||||
step_count = 0
|
||||
step_time = 0.0
|
||||
reset_time = 0.0
|
||||
model_time = 0.0
|
||||
# episode of each environment
|
||||
episode_count = np.zeros(self.env_num)
|
||||
# If n_episode is a list, and some envs have collected the required
|
||||
# number of episodes, these envs will be recorded in this list, and
|
||||
# they will not be stepped.
|
||||
finished_env_ids = []
|
||||
rewards = []
|
||||
whole_data = Batch()
|
||||
if isinstance(n_episode, list):
|
||||
assert len(n_episode) == self.get_env_num()
|
||||
finished_env_ids = [i for i in self._ready_env_ids if n_episode[i] <= 0]
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
while True:
|
||||
if step_count >= 100000 and episode_count.sum() == 0:
|
||||
warnings.warn(
|
||||
"There are already many steps in an episode. "
|
||||
"You should add a time limitation to your environment!",
|
||||
Warning,
|
||||
)
|
||||
|
||||
is_async = self.is_async or len(finished_env_ids) > 0
|
||||
if is_async:
|
||||
# self.data are the data for all environments in async
|
||||
# simulation or some envs have finished,
|
||||
# **only a subset of data are disposed**,
|
||||
# so we store the whole data in ``whole_data``, let self.data
|
||||
# to be the data available in ready environments, and finally
|
||||
# set these back into all the data
|
||||
whole_data = self.data
|
||||
self.data = self.data[self._ready_env_ids]
|
||||
|
||||
# restore the state and the input data
|
||||
last_state = self.data.state
|
||||
if isinstance(last_state, Batch) and last_state.is_empty():
|
||||
last_state = None
|
||||
self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
|
||||
|
||||
# calculate the next action
|
||||
start = time.time()
|
||||
if random:
|
||||
spaces = self._action_space
|
||||
result = Batch(act=[spaces[i].sample() for i in self._ready_env_ids])
|
||||
else:
|
||||
if no_grad:
|
||||
with torch.no_grad(): # faster than retain_grad version
|
||||
result = self.policy(self.data, last_state)
|
||||
else:
|
||||
result = self.policy(self.data, last_state)
|
||||
model_time += time.time() - start
|
||||
state = result.get("state", Batch())
|
||||
# convert None to Batch(), since None is reserved for 0-init
|
||||
if state is None:
|
||||
state = Batch()
|
||||
self.data.update(state=state, policy=result.get("policy", Batch()))
|
||||
# save hidden state to policy._state, in order to save into buffer
|
||||
if not (isinstance(state, Batch) and state.is_empty()):
|
||||
self.data.policy._state = self.data.state
|
||||
|
||||
self.data.act = to_numpy(result.act)
|
||||
if self._action_noise is not None:
|
||||
assert isinstance(self.data.act, np.ndarray)
|
||||
self.data.act += self._action_noise(self.data.act.shape)
|
||||
|
||||
# step in env
|
||||
start = time.time()
|
||||
if not is_async:
|
||||
obs_next, rew, done, info = self.env.step(self.data.act)
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
else:
|
||||
# store computed actions, states, etc
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# fetch finished data
|
||||
obs_next, rew, done, info = self.env.step(self.data.act, id=self._ready_env_ids)
|
||||
self._ready_env_ids = np.array([i["env_id"] for i in info])
|
||||
# get the stepped data
|
||||
self.data = whole_data[self._ready_env_ids]
|
||||
if log_fn:
|
||||
log_fn(info)
|
||||
|
||||
step_time += time.time() - start
|
||||
# move data to self.data
|
||||
self.data.update(obs_next=obs_next, rew=rew, done=done, info=[{} for i in info])
|
||||
|
||||
if render:
|
||||
self.env.render()
|
||||
time.sleep(render)
|
||||
|
||||
# add data into the buffer
|
||||
if self.preprocess_fn:
|
||||
result = self.preprocess_fn(**self.data) # type: ignore
|
||||
self.data.update(result)
|
||||
|
||||
for j, i in enumerate(self._ready_env_ids):
|
||||
# j is the index in current ready_env_ids
|
||||
# i is the index in all environments
|
||||
if self.buffer is None:
|
||||
# users do not want to store data, so we store
|
||||
# small fake data here to make the code clean
|
||||
self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)
|
||||
else:
|
||||
self._cached_buf[i].add(**self.data[j])
|
||||
|
||||
if done[j]:
|
||||
if not (isinstance(n_episode, list) and episode_count[i] >= n_episode[i]):
|
||||
episode_count[i] += 1
|
||||
rewards.append(self._rew_metric(np.sum(self._cached_buf[i].rew, axis=0)))
|
||||
step_count += len(self._cached_buf[i])
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
if isinstance(n_episode, list) and episode_count[i] >= n_episode[i]:
|
||||
# env i has collected enough data, it has finished
|
||||
finished_env_ids.append(i)
|
||||
self._cached_buf[i].reset()
|
||||
self._reset_state(j)
|
||||
obs_next = self.data.obs_next
|
||||
start = time.time()
|
||||
if sum(done):
|
||||
env_ind_local = np.where(done)[0].tolist()
|
||||
env_ind_global = self._ready_env_ids[env_ind_local]
|
||||
obs_reset, stop_id = self.env.reset(env_ind_global)
|
||||
_ready_env_ids = self._ready_env_ids.tolist()
|
||||
for i in stop_id:
|
||||
finished_env_ids.append(i)
|
||||
# env_ind_local.remove(_ready_env_ids.index(i))
|
||||
if len(env_ind_local) > 0:
|
||||
if self.preprocess_fn:
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
||||
obs_next[env_ind_local] = obs_reset
|
||||
reset_time += time.time() - start
|
||||
self.data.obs = obs_next
|
||||
if is_async:
|
||||
# set data back
|
||||
whole_data = deepcopy(whole_data) # avoid reference in ListBuf
|
||||
_batch_set_item(whole_data, self._ready_env_ids, self.data, self.env_num)
|
||||
# let self.data be the data in all environments again
|
||||
self.data = whole_data
|
||||
self._ready_env_ids = np.array([x for x in self._ready_env_ids if x not in finished_env_ids])
|
||||
if n_step:
|
||||
if step_count >= n_step:
|
||||
break
|
||||
else:
|
||||
if isinstance(n_episode, int) and episode_count.sum() >= n_episode:
|
||||
break
|
||||
if isinstance(n_episode, list) and (episode_count >= n_episode).all():
|
||||
break
|
||||
if len(self._ready_env_ids) == 0 and self.testing:
|
||||
break
|
||||
|
||||
# finished envs are ready, and can be used for the next collection
|
||||
self._ready_env_ids = np.array(self._ready_env_ids.tolist() + finished_env_ids)
|
||||
|
||||
# generate the statistics
|
||||
episode_count = sum(episode_count)
|
||||
duration = max(time.time() - start_time, 1e-9)
|
||||
self.collect_step += step_count
|
||||
self.collect_episode += episode_count
|
||||
self.collect_time += duration
|
||||
return {
|
||||
"n/ep": episode_count,
|
||||
"n/st": step_count,
|
||||
"v/st": step_count / duration,
|
||||
"v/ep": episode_count / duration,
|
||||
"t/st": step_time / step_count,
|
||||
"t/re": reset_time / episode_count,
|
||||
"t/mo": model_time / step_count,
|
||||
"rew": np.mean(rewards),
|
||||
"rew_std": np.std(rewards),
|
||||
"len": step_count / episode_count,
|
||||
}
|
||||
1
examples/trade/env/__init__.py
vendored
Normal file
1
examples/trade/env/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
from .env_rl import *
|
||||
481
examples/trade/env/env_rl.py
vendored
Normal file
481
examples/trade/env/env_rl.py
vendored
Normal file
@@ -0,0 +1,481 @@
|
||||
import gym
|
||||
|
||||
gym.logger.set_level(40)
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle as pkl
|
||||
import datetime
|
||||
import random
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import tianshou as ts
|
||||
import copy
|
||||
from multiprocessing import Process, Pipe, Queue
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
from scipy.stats import pearsonr
|
||||
from sklearn.metrics import roc_auc_score
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import merge_dicts, nan_weighted_avg, robust_auc
|
||||
import reward
|
||||
import observation
|
||||
import action
|
||||
|
||||
ZERO = 1e-7
|
||||
|
||||
|
||||
class StockEnv(gym.Env):
|
||||
"""Single-assert environment"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.max_step_num = config["max_step_num"]
|
||||
self.limit = config["limit"]
|
||||
self.time_interval = config["time_interval"]
|
||||
self.interval_num = config["interval_num"]
|
||||
self.offset = config["offset"] if "offset" in config else 0
|
||||
if "last_reward" in config:
|
||||
self.last_reward = config["last_reward"]
|
||||
else:
|
||||
self.last_reward = None
|
||||
if "log" in config:
|
||||
self.log = config["log"]
|
||||
else:
|
||||
self.log = True
|
||||
# loader_conf = config['loader']['config']
|
||||
obs_conf = config["obs"]["config"]
|
||||
obs_conf["features"] = config["features"]
|
||||
obs_conf["time_interval"] = self.time_interval
|
||||
obs_conf["max_step_num"] = self.max_step_num
|
||||
self.obs = getattr(observation, config["obs"]["name"])(obs_conf)
|
||||
self.action_func = getattr(action, config["action"]["name"])(config["action"]["config"])
|
||||
self.reward_func_list = []
|
||||
self.reward_log_dict = {}
|
||||
self.reward_coef = []
|
||||
for name, conf in config["reward"].items():
|
||||
self.reward_coef.append(conf.pop("coefficient"))
|
||||
self.reward_func_list.append(getattr(reward, name)(conf))
|
||||
self.reward_log_dict[name] = 0.0
|
||||
self.observation_space = self.obs.get_space()
|
||||
self.action_space = self.action_func.get_space()
|
||||
|
||||
def toggle_log(self, log):
|
||||
self.log = log
|
||||
|
||||
def reset(self, sample):
|
||||
"""
|
||||
|
||||
:param sample:
|
||||
|
||||
"""
|
||||
|
||||
for key in self.reward_log_dict.keys():
|
||||
self.reward_log_dict[key] = 0.0
|
||||
if not sample is None:
|
||||
(
|
||||
self.ins,
|
||||
self.date,
|
||||
self.raw_df_values,
|
||||
self.raw_df_columns,
|
||||
self.raw_df_index,
|
||||
self.feature_dfs,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
) = sample
|
||||
self.raw_df = pd.DataFrame(index=self.raw_df_index, data=self.raw_df_values, columns=self.raw_df_columns,)
|
||||
del self.raw_df_values, self.raw_df_columns, self.raw_df_index
|
||||
start_time = time.time()
|
||||
self.load_time = time.time() - start_time
|
||||
self.day_vwap = nan_weighted_avg(
|
||||
self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
self.raw_df["$volume0"].values[self.offset : self.offset + self.max_step_num],
|
||||
)
|
||||
try:
|
||||
assert not (np.isnan(self.day_vwap) or np.isinf(self.day_vwap))
|
||||
except:
|
||||
print(self.raw_df)
|
||||
print(self.ins)
|
||||
print(self.day_vwap)
|
||||
self.raw_df.to_pickle("/nfs_data1/kanren/error_df.pkl")
|
||||
self.day_twap = np.nanmean(self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num])
|
||||
self.t = -1 + self.offset
|
||||
self.interval = 0
|
||||
self.position = self.target
|
||||
self.eps_start = time.time()
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
)
|
||||
if self.log:
|
||||
index_array = [
|
||||
np.array([self.ins] * self.max_step_num),
|
||||
self.raw_df.index.to_numpy()[self.offset : self.offset + self.max_step_num],
|
||||
np.array([self.date] * self.max_step_num),
|
||||
]
|
||||
self.traded_log = pd.DataFrame(
|
||||
data={
|
||||
"$v_t": np.nan,
|
||||
"$max_vol_t": (self.raw_df["$volume0"] * self.limit).values[
|
||||
self.offset : self.offset + self.max_step_num
|
||||
],
|
||||
"$traded_t": np.nan,
|
||||
"$vwap_t": self.raw_df["$vwap0"].values[self.offset : self.offset + self.max_step_num],
|
||||
"action": np.nan,
|
||||
},
|
||||
index=index_array,
|
||||
)
|
||||
# v_t: The amount of shares the agent hope to trade
|
||||
# max_vol_t: The max amount of shares can be traded
|
||||
# traded_t: The amount of shares that is acually traded
|
||||
# action: the action of agent, may have various meanings in different settings.
|
||||
self.done = False
|
||||
if self.limit > 1:
|
||||
self.this_valid = np.inf
|
||||
else:
|
||||
self.this_valid = np.nansum(self.raw_df["$volume0"].values) * self.limit
|
||||
self.this_cash = 0
|
||||
|
||||
self.step_time = []
|
||||
self.action_log = [np.nan] * self.interval_num
|
||||
self.reset_time = time.time() - start_time
|
||||
self.real_eps_time = self.reset_time
|
||||
self.total_reward = 0
|
||||
self.total_instant_rew = 0
|
||||
self.last_rew = 0
|
||||
return self.state
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
|
||||
:param action:
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
|
||||
for i in range(self.time_interval):
|
||||
v_t = volume_t / min(self.time_interval, time_left)
|
||||
self.t += 1
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
v_t = self.position
|
||||
if self.log:
|
||||
log_index = self.t - self.offset
|
||||
self.traded_log.iat[log_index, 0] = v_t
|
||||
self.traded_log.iat[log_index, 4] = action
|
||||
vwap_t, vol_t = self.raw_df.iloc[self.t][["$vwap0", "$volume0"]]
|
||||
max_vol_t = self.limit * vol_t
|
||||
if self.limit >= 1:
|
||||
max_vol_t = np.inf
|
||||
if v_t > min(self.position, max_vol_t):
|
||||
if self.position <= max_vol_t:
|
||||
v_t = self.position
|
||||
else:
|
||||
v_t = max_vol_t
|
||||
self.position -= v_t
|
||||
self.this_cash += vwap_t * v_t
|
||||
if self.log:
|
||||
self.traded_log.iat[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - vwap_t / self.day_vwap) * 10000
|
||||
PA_t = (1 - vwap_t / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (vwap_t / self.day_vwap - 1) * 10000
|
||||
PA_t = (vwap_t / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.t == self.max_step_num - 1 + self.offset:
|
||||
break
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
|
||||
|
||||
class StockEnv_Acc(StockEnv):
|
||||
def step(self, action):
|
||||
start_time = time.time()
|
||||
self.action_log[self.interval] = action
|
||||
volume_t = self.action_func(
|
||||
action,
|
||||
self.target,
|
||||
self.position,
|
||||
max_step_num=self.max_step_num,
|
||||
t=self.t - self.offset,
|
||||
interval=self.interval,
|
||||
interval_num=self.interval_num,
|
||||
)
|
||||
self.interval += 1
|
||||
reward = 0.0
|
||||
time_left = self.max_step_num - self.t - 1 + self.offset
|
||||
time_left = min(self.time_interval, time_left)
|
||||
|
||||
v_t = np.repeat(volume_t / time_left, time_left)
|
||||
minutes = np.arange(self.t + 1, self.t + time_left + 1)
|
||||
if self.log:
|
||||
log_index = minutes - self.offset
|
||||
self.traded_log.iloc[log_index, 0] = v_t
|
||||
self.traded_log.iloc[log_index, 4] = action
|
||||
vwap_t = self.raw_df.iloc[minutes]["$vwap0"].values
|
||||
vol_t = self.raw_df.iloc[minutes]["$volume0"].values
|
||||
max_vol_t = self.limit * vol_t if self.limit < 1 else np.inf
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
if self.t + time_left == self.max_step_num - 1 + self.offset:
|
||||
left = self.position - v_t.sum()
|
||||
v_t[-1] += left
|
||||
v_t = np.minimum(v_t, max_vol_t)
|
||||
this_money = (v_t * vwap_t).sum()
|
||||
this_vol = v_t.sum()
|
||||
this_vwap = np.nan_to_num(this_money / this_vol)
|
||||
self.t += time_left
|
||||
self.position -= this_vol
|
||||
self.this_cash += this_money
|
||||
if self.log:
|
||||
self.traded_log.iloc[log_index, 2] = v_t
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vwap / self.day_vwap) * 10000
|
||||
PA_t = (1 - this_vwap / self.day_twap) * 10000
|
||||
else:
|
||||
performance_raise = (this_vwap / self.day_vwap - 1) * 10000
|
||||
PA_t = (this_vwap / self.day_twap - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, v_t, self.target, PA_t)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
if self.position < ZERO:
|
||||
self.done = True
|
||||
|
||||
if self.interval == self.interval_num:
|
||||
self.done = True
|
||||
|
||||
self.step_time.append(time.time() - start_time)
|
||||
self.real_eps_time += time.time() - start_time
|
||||
if self.done:
|
||||
this_traded = self.target - self.position
|
||||
this_vwap = (self.this_cash / this_traded) if this_traded > ZERO else self.day_vwap
|
||||
valid = min(self.target, self.this_valid)
|
||||
this_ffr = (this_traded / valid) if valid > ZERO else 1.0
|
||||
if abs(this_ffr - 1.0) < ZERO:
|
||||
this_ffr = 1.0
|
||||
this_ffr *= 100
|
||||
this_vv_ratio = this_vwap / self.day_vwap
|
||||
vwap = self.raw_df["$vwap0"].values[self.offset : self.max_step_num + self.offset]
|
||||
this_tt_ratio = this_vwap / np.nanmean(vwap)
|
||||
|
||||
if self.is_buy:
|
||||
performance_raise = (1 - this_vv_ratio) * 10000
|
||||
PA = (1 - this_tt_ratio) * 10000
|
||||
else:
|
||||
performance_raise = (this_vv_ratio - 1) * 10000
|
||||
PA = (this_tt_ratio - 1) * 10000
|
||||
|
||||
for i, reward_func in enumerate(self.reward_func_list):
|
||||
if not reward_func.isinstant:
|
||||
tmp_r = reward_func(performance_raise, this_ffr, this_tt_ratio, self.is_buy)
|
||||
reward += tmp_r * self.reward_coef[i]
|
||||
self.reward_log_dict[type(reward_func).__name__] += tmp_r
|
||||
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
if self.log:
|
||||
res = pd.DataFrame(
|
||||
{
|
||||
"target": self.target,
|
||||
"sell": not self.is_buy,
|
||||
"vwap": this_vwap,
|
||||
"this_vv_ratio": this_vv_ratio,
|
||||
"this_ffr": this_ffr,
|
||||
},
|
||||
index=[[self.ins], [self.date]],
|
||||
)
|
||||
money = self.target * self.day_vwap
|
||||
if self.is_buy:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_buy": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_buy": this_ffr,
|
||||
"PR_buy": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_buy": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
else:
|
||||
info = {
|
||||
"money": money,
|
||||
"money_sell": money,
|
||||
"action": self.action_log,
|
||||
"ffr": this_ffr,
|
||||
"obs0_PR": performance_raise,
|
||||
"ffr_sell": this_ffr,
|
||||
"PR_sell": performance_raise,
|
||||
"PA": PA,
|
||||
"PA_sell": PA,
|
||||
"vwap": this_vwap,
|
||||
}
|
||||
info = merge_dicts(info, self.reward_log_dict)
|
||||
if self.log:
|
||||
info["df"] = self.traded_log
|
||||
info["res"] = res
|
||||
del self.feature_dfs
|
||||
return self.state, reward, self.done, info
|
||||
|
||||
else:
|
||||
self.state = self.obs(
|
||||
self.raw_df,
|
||||
self.feature_dfs,
|
||||
self.t,
|
||||
self.interval,
|
||||
self.position,
|
||||
self.target,
|
||||
self.is_buy,
|
||||
self.max_step_num,
|
||||
self.interval_num,
|
||||
action,
|
||||
)
|
||||
return self.state, reward, self.done, {}
|
||||
351
examples/trade/executor.py
Normal file
351
examples/trade/executor.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import env
|
||||
from vecenv import *
|
||||
import sampler
|
||||
import logger
|
||||
import json
|
||||
import os
|
||||
import agent
|
||||
import network
|
||||
import policy
|
||||
import random
|
||||
import tianshou as ts
|
||||
import tqdm
|
||||
from tianshou.utils import tqdm_config, MovAvg
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from collector import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
from util import merge_dicts
|
||||
|
||||
|
||||
def get_best_gpu(force=None):
|
||||
if force is not None:
|
||||
return force
|
||||
s = os.popen("nvidia-smi --query-gpu=memory.free --format=csv")
|
||||
a = []
|
||||
ss = s.read().replace("MiB", "").replace("memory.free", "").split("\n")
|
||||
s.close()
|
||||
for i in range(1, len(ss) - 1):
|
||||
a.append(int(ss[i]))
|
||||
best = int(np.argmax(a))
|
||||
print("the best GPU is ", best, " with free memories of ", ss[best + 1])
|
||||
return best
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
"""
|
||||
|
||||
:param seed:
|
||||
|
||||
"""
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
class BaseExecutor(object):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
):
|
||||
"""A base class for executor
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param optim: Optimization configuration, defaults to None
|
||||
:type optim: dict, optional
|
||||
:param policy_conf: Configurations for the RL algorithm, defaults to None
|
||||
:type policy_conf: dict, optional
|
||||
:param network_conf: Configurations for policy network_conf, defaults to None
|
||||
:type network_conf: dict, optional
|
||||
:param policy_path: If is not None, would load the policy from this path, defaults to None
|
||||
:type policy_path: string, optional
|
||||
:param seed: Random seed, defaults to None
|
||||
:type seed: int, optional
|
||||
"""
|
||||
# self.config = config
|
||||
self.log_dir = log_dir
|
||||
print(self.log_dir)
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
if resources["device"] == "cuda":
|
||||
resources["device"] = "cuda:" + str(get_best_gpu())
|
||||
self.device = torch.device(resources["device"])
|
||||
if seed:
|
||||
setup_seed(seed)
|
||||
|
||||
assert not policy_path is None or not policy_conf is None, "Policy must be defined"
|
||||
if policy_path:
|
||||
self.policy = torch.load(policy_path, map_location=self.device)
|
||||
self.policy.actor.extractor.device = self.device
|
||||
# policy.eval()
|
||||
elif hasattr(agent, policy_conf["name"]):
|
||||
policy_conf["config"] = merge_dicts(policy_conf["config"], resources)
|
||||
self.policy = getattr(agent, policy_conf["name"])(policy_conf["config"])
|
||||
# print(self.policy)
|
||||
else:
|
||||
assert not network_conf is None
|
||||
if "extractor" in network_conf.keys():
|
||||
net = getattr(network, network_conf["extractor"]["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
else:
|
||||
net = getattr(network, network_conf["name"] + "_Extractor")(
|
||||
device=self.device, **network_conf["config"]
|
||||
)
|
||||
net.to(self.device)
|
||||
actor = getattr(network, network_conf["name"] + "_Actor")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
actor.to(self.device)
|
||||
critic = getattr(network, network_conf["name"] + "_Critic")(
|
||||
extractor=net, device=self.device, **network_conf["config"]
|
||||
)
|
||||
critic.to(self.device)
|
||||
self.optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()),
|
||||
lr=optim["lr"],
|
||||
weight_decay=optim["weight_decay"] if "weight_decay" in optim else 0.0,
|
||||
)
|
||||
self.dist = torch.distributions.Categorical
|
||||
try:
|
||||
self.policy = getattr(ts.policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
except:
|
||||
self.policy = getattr(policy, policy_conf["name"])(
|
||||
actor, critic, self.optim, self.dist, **policy_conf["config"]
|
||||
)
|
||||
self.writer = SummaryWriter(self.log_dir)
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""Run the whole training process.
|
||||
|
||||
:param max_epoch: The total number of epoch.
|
||||
:param step_per_epoch: The times of bp in one epoch.
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
:param iteration: The iteration when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param global_step: The number of steps when starting the training, used when fine tuning. (Default value = 0)
|
||||
:param early_stopping: If the test reward does not reach a new high in `early_stopping` iterations, the training would stop. (Default value = 5)
|
||||
:returns: The result on test set.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
"""Do an round of training
|
||||
|
||||
:param collect_per_step: Number of episodes to collect before one bp.
|
||||
:param repeat_per_collect: Times of bps after every rould of experience collecting.
|
||||
:param batch_size: Batch size when bp.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
"""Evaluate the policy on orders in order_dir
|
||||
|
||||
:param order_dir: the orders to be evaluated on.
|
||||
:param save_res: whether the result of evaluation be saved to self.logdir/res.json (Default value = False)
|
||||
:param logdir: the place to save the .log and .pkl log files to. If None, don't save logfiles. (Default value = None)
|
||||
:returns: The result of evaluation.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Executor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
log_dir,
|
||||
resources,
|
||||
env_conf,
|
||||
train_paths,
|
||||
valid_paths,
|
||||
test_paths,
|
||||
io_conf,
|
||||
optim=None,
|
||||
policy_conf=None,
|
||||
network_conf=None,
|
||||
policy_path=None,
|
||||
seed=None,
|
||||
share_memory=False,
|
||||
buffer_size=200000,
|
||||
q_learning=False,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
"""[summary]
|
||||
|
||||
:param log_dir: The directory to write all the logs.
|
||||
:type log_dir: string
|
||||
:param resources: A dict which describes available computational resources.
|
||||
:type resources: dict
|
||||
:param env_conf: Configurations for the envionments.
|
||||
:type env_conf: dict
|
||||
:param train_paths: The paths of training datasets including orders, backtest files and features.
|
||||
:type train_paths: string
|
||||
:param valid_paths: The paths of validation datasets including orders, backtest files and features.
|
||||
:type valid_paths: string
|
||||
:param test_paths: The paths of test datasets including orders, backtest files and features.
|
||||
:type test_paths: string
|
||||
:param io_conf: Configuration for sampler and loggers.
|
||||
:type io_conf: dict
|
||||
:param share_memory: Whether to use shared memory vecnev, defaults to False
|
||||
:type share_memory: bool, optional
|
||||
:param buffer_size: The size of replay buffer, defaults to 200000
|
||||
:type buffer_size: int, optional
|
||||
"""
|
||||
super().__init__(log_dir, resources, env_conf, optim, policy_conf, network_conf, policy_path, seed)
|
||||
single_env = getattr(env, env_conf["name"])
|
||||
env_conf = merge_dicts(env_conf, train_paths)
|
||||
env_conf["log"] = True
|
||||
print("CPU_COUNT:", resources["num_cpus"])
|
||||
if share_memory:
|
||||
self.env = ShmemVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
else:
|
||||
self.env = SubprocVectorEnv([lambda: single_env(env_conf) for _ in range(resources["num_cpus"])])
|
||||
self.test_collector = Collector(policy=self.policy, env=self.env, testing=True, reward_metric=np.sum)
|
||||
self.train_collector = Collector(
|
||||
self.policy, self.env, buffer=ts.data.ReplayBuffer(buffer_size), reward_metric=np.sum,
|
||||
)
|
||||
self.train_paths = train_paths
|
||||
self.test_paths = test_paths
|
||||
self.valid_paths = valid_paths
|
||||
train_sampler_conf = train_paths
|
||||
train_sampler_conf["features"] = env_conf["features"]
|
||||
test_sampler_conf = test_paths
|
||||
test_sampler_conf["features"] = env_conf["features"]
|
||||
self.train_sampler = getattr(sampler, io_conf["train_sampler"])(train_sampler_conf)
|
||||
self.test_sampler = getattr(sampler, io_conf["test_sampler"])(test_sampler_conf)
|
||||
self.train_logger = logger.InfoLogger()
|
||||
self.test_logger = getattr(logger, io_conf["test_logger"])
|
||||
|
||||
self.q_learning = q_learning
|
||||
|
||||
def train(
|
||||
self,
|
||||
max_epoch,
|
||||
step_per_epoch,
|
||||
repeat_per_collect,
|
||||
collect_per_step,
|
||||
batch_size,
|
||||
iteration=0,
|
||||
global_step=0,
|
||||
early_stopping=5,
|
||||
train_step_min=0,
|
||||
log_valid=True,
|
||||
*args,
|
||||
**kargs,
|
||||
):
|
||||
best_epoch, best_reward = -1, -1
|
||||
stat = {}
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
with tqdm.tqdm(total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
|
||||
while t.n < t.total:
|
||||
result, losses = self.train_round(repeat_per_collect, collect_per_step, batch_size, iteration)
|
||||
global_step += result["n/st"]
|
||||
iteration += 1
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Train/" + k, result[k], global_step=global_step)
|
||||
for k in losses.keys():
|
||||
if stat.get(k) is None:
|
||||
stat[k] = MovAvg()
|
||||
stat[k].add(losses[k])
|
||||
self.writer.add_scalar("Train/" + k, stat[k].get(), global_step=global_step)
|
||||
t.update(1)
|
||||
if t.n <= t.total:
|
||||
t.update()
|
||||
result = self.eval(
|
||||
self.valid_paths["order_dir"], logdir=f"{self.log_dir}/valid/{iteration}/" if log_valid else None,
|
||||
)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Valid/" + k, result[k], global_step=global_step)
|
||||
if best_epoch == -1 or best_reward < result["rew"]:
|
||||
best_reward = result["rew"]
|
||||
best_epoch = epoch
|
||||
best_state = self.policy.state_dict()
|
||||
early_stop_round = 0
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_best")
|
||||
elif global_step >= train_step_min:
|
||||
early_stop_round += 1
|
||||
torch.save(self.policy, f"{self.log_dir}/policy_{epoch}")
|
||||
print(
|
||||
f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, ' # train_reward: {result_train["rew"]:.4f}, '
|
||||
f"best_reward: {best_reward:.4f} in #{best_epoch}"
|
||||
)
|
||||
if early_stop_round >= early_stopping:
|
||||
print("Early stopped")
|
||||
break
|
||||
print("Testing...")
|
||||
self.policy.load_state_dict(best_state)
|
||||
result = self.eval(self.test_paths["order_dir"], logdir=f"{self.log_dir}/test/", save_res=True)
|
||||
for k in result.keys():
|
||||
self.writer.add_scalar("Test/" + k, result[k], global_step=global_step)
|
||||
return result
|
||||
|
||||
def train_round(self, repeat_per_collect, collect_per_step, batch_size, *args, **kargs):
|
||||
self.policy.train()
|
||||
self.env.toggle_log(False)
|
||||
self.env.sampler = self.train_sampler
|
||||
if not self.q_learning:
|
||||
self.train_collector.reset()
|
||||
result = self.train_collector.collect(n_episode=collect_per_step, log_fn=self.train_logger)
|
||||
result = merge_dicts(result, self.train_logger.summary())
|
||||
if not self.q_learning:
|
||||
losses = self.policy.update(
|
||||
0, self.train_collector.buffer, batch_size=batch_size, repeat=repeat_per_collect,
|
||||
)
|
||||
else:
|
||||
losses = self.policy.update(batch_size, self.train_collector.buffer,)
|
||||
return result, losses
|
||||
|
||||
def eval(self, order_dir, save_res=False, logdir=None, *args, **kargs):
|
||||
print(f"start evaluating on {order_dir}")
|
||||
self.policy.eval()
|
||||
self.env.toggle_log(True)
|
||||
self.test_sampler.reset(order_dir)
|
||||
self.env.sampler = self.test_sampler
|
||||
self.test_collector.reset()
|
||||
if not logdir is None:
|
||||
if not os.path.exists(logdir):
|
||||
os.makedirs(logdir)
|
||||
eval_logger = self.test_logger(logdir, order_dir)
|
||||
eval_logger.reset()
|
||||
else:
|
||||
eval_logger = self.train_logger
|
||||
result = self.test_collector.collect(log_fn=eval_logger)
|
||||
result = merge_dicts(result, eval_logger.summary())
|
||||
if save_res:
|
||||
with open(self.log_dir + "/res.json", "w") as f:
|
||||
json.dump(result, f, sort_keys=True, indent=4)
|
||||
print(f"finish evaluating on {order_dir}")
|
||||
return result
|
||||
76
examples/trade/exp/example/OPD/config.yml
Normal file
76
examples/trade/exp/example/OPD/config.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPD
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
- name: teacher_action
|
||||
type: interval
|
||||
size: 1
|
||||
loc: ../data/feature/teacher/
|
||||
obs:
|
||||
name: RuleTeacher
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO_sup
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
sup_coef: 0.01
|
||||
network_conf:
|
||||
name: OPD
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
71
examples/trade/exp/example/OPDS/config.yml
Normal file
71
examples/trade/exp/example/OPDS/config.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPDS
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: PPO
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
71
examples/trade/exp/example/OPDT/config.yml
Normal file
71
examples/trade/exp/example/OPDT/config.yml
Normal file
@@ -0,0 +1,71 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/OPDT
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: Teacher
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
76
examples/trade/exp/example/OPDT_b/config.yml
Normal file
76
examples/trade/exp/example/OPDT_b/config.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
seed: 42
|
||||
task: eval
|
||||
log_dir: example/OPDT_b
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/all/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
VP_Penalty_small_vec:
|
||||
penalty: 100
|
||||
coefficient: 1
|
||||
policy_path: policy_best
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: Teacher
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
search:
|
||||
optim.weight_decay:
|
||||
type: choice
|
||||
value: [0.]
|
||||
70
examples/trade/exp/example/PPO/config.yml
Normal file
70
examples/trade/exp/example/PPO/config.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
seed: 42
|
||||
task: train
|
||||
log_dir: example/PPO
|
||||
buffer_size: 80000
|
||||
io_conf:
|
||||
test_sampler: TestSampler
|
||||
train_sampler: Sampler
|
||||
test_logger: DFLogger
|
||||
resources:
|
||||
num_cpus: 24
|
||||
num_gpus: 1
|
||||
device: cuda
|
||||
train_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/train/
|
||||
valid_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/valid/
|
||||
test_paths:
|
||||
raw_dir: ../data/backtest/
|
||||
order_dir: ../data/order/test/
|
||||
env_conf:
|
||||
name: StockEnv_Acc
|
||||
max_step_num: 237
|
||||
limit: 10
|
||||
time_interval: 30
|
||||
interval_num: 8
|
||||
features:
|
||||
- name: raw
|
||||
type: range
|
||||
loc: ../data/normed_feature/
|
||||
size: 180
|
||||
obs:
|
||||
name: TeacherObs
|
||||
config: {}
|
||||
action:
|
||||
name: Static_Action
|
||||
config:
|
||||
action_num: 5
|
||||
action_map: [0, 0.25, 0.5, 0.75, 1]
|
||||
reward:
|
||||
PPO_Reward:
|
||||
coefficient: 1
|
||||
policy_conf:
|
||||
name: PPO
|
||||
config:
|
||||
discount_factor: 1.
|
||||
max_grad_norm: 100.
|
||||
reward_normalization: False
|
||||
eps_clip: 0.3
|
||||
value_clip: True
|
||||
vf_coef: 1.
|
||||
gae_lambda: 1.
|
||||
vf_clip_para: 0.3
|
||||
network_conf:
|
||||
name: PPO
|
||||
config:
|
||||
hidden_size: 64
|
||||
out_shape: 5
|
||||
fc_size: 32
|
||||
cnn_shape: [30, 6]
|
||||
optim:
|
||||
lr: 1e-4
|
||||
batch_size: 1024
|
||||
max_epoch: 30
|
||||
step_per_epoch: 20
|
||||
collect_per_step: 10000
|
||||
repeat_per_collect: 5
|
||||
early_stopping: 5
|
||||
weight_decay: 0.
|
||||
1
examples/trade/logger/__init__.py
Normal file
1
examples/trade/logger/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .single_logger import *
|
||||
231
examples/trade/logger/single_logger.py
Normal file
231
examples/trade/logger/single_logger.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
from multiprocessing import Queue, Process
|
||||
import time
|
||||
|
||||
|
||||
def GLR(values):
|
||||
"""
|
||||
|
||||
Calculate -P(value | value > 0) / P(value | value < 0)
|
||||
|
||||
"""
|
||||
pos = []
|
||||
neg = []
|
||||
for i in values:
|
||||
if i > 0:
|
||||
pos.append(i)
|
||||
elif i < 0:
|
||||
neg.append(i)
|
||||
return -np.mean(pos) / np.mean(neg)
|
||||
|
||||
|
||||
class DFLogger(object):
|
||||
"""The logger for single-assert backtest.
|
||||
Would save .pkl and .log in log_dir
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, log_dir, order_dir, writer=None):
|
||||
self.order_dir = order_dir + "/"
|
||||
self.log_dir = log_dir + "/"
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
self.queue = Queue(100000)
|
||||
self.raw_log_dir = self.log_dir
|
||||
|
||||
@staticmethod
|
||||
def _worker(log_dir, order_dir, queue):
|
||||
df_cache = {}
|
||||
stat_cache = {}
|
||||
if not os.path.exists(log_dir):
|
||||
os.mkdir(log_dir)
|
||||
while True:
|
||||
info = queue.get(block=True)
|
||||
if info == "stop":
|
||||
summary = {}
|
||||
for k, v in stat_cache.items():
|
||||
if not k.startswith("money"):
|
||||
summary[k + "_std"] = np.nanstd(v)
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
# summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache['money_sell'])
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
try:
|
||||
summary["GLR_sell"] = GLR(stat_cache["PA_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
summary["GLR_buy"] = GLR(stat_cache["PA_buy"])
|
||||
except:
|
||||
pass
|
||||
queue.put(summary)
|
||||
break
|
||||
elif len(info) == 0:
|
||||
continue
|
||||
else:
|
||||
df = info.pop("df")
|
||||
res = info.pop("res")
|
||||
ins = df.index[0][0]
|
||||
if ins not in df_cache:
|
||||
df_cache[ins] = (
|
||||
[],
|
||||
[],
|
||||
(pd.read_pickle(order_dir + ins + ".pkl.target")['amount'] != 0).sum(),
|
||||
)
|
||||
df_cache[ins][0].append(df)
|
||||
df_cache[ins][1].append(res)
|
||||
if len(df_cache[ins][0]) == df_cache[ins][2]:
|
||||
pd.concat(df_cache[ins][0]).to_pickle(log_dir + ins + ".log")
|
||||
pd.concat(df_cache[ins][1]).to_pickle(log_dir + ins + ".pkl")
|
||||
del df_cache[ins]
|
||||
for k, v in info.items():
|
||||
if k not in stat_cache:
|
||||
stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
stat_cache[k] += list(v)
|
||||
else:
|
||||
stat_cache[k].append(v)
|
||||
|
||||
def reset(self):
|
||||
""" """
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
assert self.queue.empty()
|
||||
self.child = Process(target=self._worker, args=(self.log_dir, self.order_dir, self.queue), daemon=True,)
|
||||
self.child.start()
|
||||
|
||||
def set_step(self, step):
|
||||
|
||||
self.log_dir = f"{self.raw_log_dir}{step}/"
|
||||
self.reset()
|
||||
|
||||
def __call__(self, infos):
|
||||
for info in infos:
|
||||
if "env_id" in info:
|
||||
info.pop("env_id")
|
||||
self.update(infos)
|
||||
|
||||
def update(self, infos):
|
||||
"""store values in info into the logger"""
|
||||
for info in infos:
|
||||
self.queue.put(info, block=True)
|
||||
|
||||
def summary(self):
|
||||
""":return: The mean and std of values in infos stored in logger"""
|
||||
summary = {}
|
||||
self.queue.put("stop", block=True)
|
||||
self.child.join()
|
||||
self.child.close()
|
||||
assert self.queue.qsize() == 1
|
||||
summary = self.queue.get()
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
class InfoLogger(DFLogger):
|
||||
""" """
|
||||
|
||||
def __init__(self, *args):
|
||||
self.stat_cache = {}
|
||||
self.queue = Queue(10000)
|
||||
self.child = Process(target=self._worker, args=(self.queue,), daemon=True)
|
||||
self.child.start()
|
||||
|
||||
def _worker(logdir, queue):
|
||||
stat_cache = {}
|
||||
while True:
|
||||
info = queue.get(block=True)
|
||||
if info == "stop":
|
||||
summary = {}
|
||||
for k, v in stat_cache.items():
|
||||
if not k.startswith("money"):
|
||||
summary[k + "_std"] = np.nanstd(v)
|
||||
summary[k + "_mean"] = np.nanmean(v)
|
||||
try:
|
||||
for k in ["PR_sell", "ffr_sell", "PA_sell"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["PR_buy", "ffr_buy", "PA_buy"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money_buy"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
for k in ["obs0_PR", "ffr", "PA"]:
|
||||
summary["weighted_" + k] = np.average(stat_cache[k], weights=stat_cache["money"])
|
||||
except:
|
||||
pass
|
||||
summary["GLR"] = GLR(stat_cache["PA"])
|
||||
try:
|
||||
summary["GLR_sell"] = GLR(stat_cache["PA_sell"])
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
summary["GLR_buy"] = GLR(stat_cache["PA_buy"])
|
||||
except:
|
||||
pass
|
||||
queue.put(summary)
|
||||
stat_cache = {}
|
||||
time.sleep(5)
|
||||
continue
|
||||
if len(info) == 0:
|
||||
continue
|
||||
for k, v in info.items():
|
||||
if k == "res" or k == "df":
|
||||
continue
|
||||
if k not in stat_cache:
|
||||
stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
stat_cache[k] += list(v)
|
||||
else:
|
||||
stat_cache[k].append(v)
|
||||
|
||||
def _update(self, info):
|
||||
if len(info) == 0:
|
||||
return
|
||||
ins = df.index[0][0]
|
||||
for k, v in info.items():
|
||||
if k not in self.stat_cache:
|
||||
self.stat_cache[k] = []
|
||||
if hasattr(v, "__len__"):
|
||||
self.stat_cache[k] += list(v)
|
||||
else:
|
||||
self.stat_cache[k].append(v)
|
||||
|
||||
def summary(self):
|
||||
""" """
|
||||
while not self.queue.empty():
|
||||
# print('not empty')
|
||||
# print(self.queue.qsize())
|
||||
time.sleep(1)
|
||||
self.queue.put("stop")
|
||||
# self.child.join()
|
||||
time.sleep(1)
|
||||
while not self.queue.qsize() == 1:
|
||||
# print(self.queue.qsize())
|
||||
time.sleep(1)
|
||||
assert self.queue.qsize() == 1
|
||||
summary = self.queue.get()
|
||||
|
||||
return summary
|
||||
|
||||
def set_step(self, step):
|
||||
return
|
||||
135
examples/trade/main.py
Normal file
135
examples/trade/main.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import re
|
||||
import os
|
||||
import argparse
|
||||
import yaml
|
||||
from executor import Executor
|
||||
import warnings
|
||||
import redis
|
||||
import subprocess
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from util import merge_dicts
|
||||
|
||||
loader = yaml.FullLoader
|
||||
loader.add_implicit_resolver(
|
||||
"tag:yaml.org,2002:float",
|
||||
re.compile(
|
||||
"""^(?:
|
||||
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
||||
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
||||
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
||||
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
||||
|[-+]?\\.(?:inf|Inf|INF)
|
||||
|\\.(?:nan|NaN|NAN))$""",
|
||||
re.X,
|
||||
),
|
||||
list("-+0123456789."),
|
||||
)
|
||||
|
||||
|
||||
def get_full_config(config, dir_name):
|
||||
while "base" in config:
|
||||
base_config = os.path.normpath(os.path.join(dir_name, config.pop("base")))
|
||||
dir_name = os.path.dirname(base_config)
|
||||
with open(base_config, "r") as f:
|
||||
base_config = yaml.load(base_config, Loader=yaml.FullLoader)
|
||||
config = merge_dicts(base_config, config)
|
||||
return config
|
||||
|
||||
|
||||
def run(config):
|
||||
log_dir = config["log_dir"]
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
with open(log_dir + "/config.yml", "w") as f:
|
||||
yaml.dump(config, f)
|
||||
executor = Executor(**config)
|
||||
if config["task"] == "train":
|
||||
return executor.train(**config["optim"])
|
||||
elif config["task"] == "eval":
|
||||
return executor.eval(config["test_paths"]["order_dir"], save_res=True, logdir=config["log_dir"] + "/test/",)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", type=str)
|
||||
parser.add_argument("-n", "--index", type=int, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(os.cpu_count())
|
||||
|
||||
EXP_PATH = os.environ["EXP_PATH"]
|
||||
config_path = os.path.normpath(os.path.join(EXP_PATH, args.config))
|
||||
EXP_NAME = os.path.relpath(config_path, EXP_PATH)
|
||||
if os.path.isdir(config_path):
|
||||
if not args.index is None:
|
||||
with open(config_path + "/configs.yml") as f:
|
||||
config_list = list(yaml.load_all(f, Loader=loader))
|
||||
config = config_list[args.index]
|
||||
if "PT_OUTPUT_DIR" in os.environ:
|
||||
config["log_dir"] = os.environ["PT_OUTPUT_DIR"]
|
||||
else:
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
config = get_full_config(config, config_path)
|
||||
run(config)
|
||||
else:
|
||||
redis_server = redis.Redis(
|
||||
host=os.environ["REDIS_SERVER"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
db=0,
|
||||
charset="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
with open(config_path + "/configs.yml") as f:
|
||||
config_list = list(yaml.load_all(f, Loader=loader))
|
||||
config_num = len(config_list)
|
||||
if not redis_server.exists(EXP_NAME):
|
||||
for i in range(config_num):
|
||||
redis_server.rpush(EXP_NAME, i)
|
||||
redis_server.set(f"{EXP_NAME}_{i}", "Pending")
|
||||
else:
|
||||
if redis_server.llen(EXP_NAME) == 0:
|
||||
for i in range(config_num):
|
||||
if (
|
||||
not redis_server.exists(f"{EXP_NAME}_{i}")
|
||||
or redis_server.get(f"{EXP_NAME}_{i}") == "Failed"
|
||||
):
|
||||
redis_server.rpush(EXP_NAME, i)
|
||||
redis_server.set(f"{EXP_NAME}_{i}", "Pending")
|
||||
print(f"Starting..., {redis_server.llen(EXP_NAME)} trails to run")
|
||||
while True:
|
||||
index = redis_server.lpop(EXP_NAME)
|
||||
if index is None:
|
||||
print("All done")
|
||||
break
|
||||
index = int(index)
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Running")
|
||||
print(f"Trail_{index} is running")
|
||||
try:
|
||||
res = subprocess.run(["python", "main.py", "--config", args.config, "--index", str(index),],)
|
||||
except KeyboardInterrupt:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
break
|
||||
if res.returncode == 0:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Finished")
|
||||
print(f"Finish running one trail, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
else:
|
||||
redis_server.set(f"{EXP_NAME}_{index}", "Failed")
|
||||
print(f"Trail_{index} has failed, {redis_server.llen(EXP_NAME)} trails to run")
|
||||
|
||||
elif os.path.isfile(config_path):
|
||||
assert config_path.endswith(".yml"), "Config file should be an yaml file"
|
||||
EXP_NAME = EXP_NAME[:-4]
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.load(f, Loader=loader)
|
||||
config = get_full_config(config, os.path.dirname(config_path))
|
||||
log_prefix = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else "../log"
|
||||
config["log_dir"] = os.path.join(log_prefix, config["log_dir"])
|
||||
run(config)
|
||||
else:
|
||||
print("The config path should be a relative path from EXP_PATH")
|
||||
5
examples/trade/network/__init__.py
Normal file
5
examples/trade/network/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .ppo import *
|
||||
from .qmodel import *
|
||||
from .teacher import *
|
||||
from .util import *
|
||||
from .opd import *
|
||||
74
examples/trade/network/opd.py
Normal file
74
examples/trade/network/opd.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class OPD_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
teacher_action = inp[:, 0]
|
||||
inp = inp[:, 1:]
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
feature = self.fc(fc_in)
|
||||
return feature, teacher_action / 2
|
||||
|
||||
|
||||
class OPD_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
feature, self.teacher_action = self.extractor(obs)
|
||||
out = self.layer_out(feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class OPD_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
feature, self.teacher_action = self.extractor(obs)
|
||||
return self.value_out(feature).squeeze(dim=-1)
|
||||
79
examples/trade/network/ppo.py
Normal file
79
examples/trade/network/ppo.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class PPO_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
# inp = torch.from_numpy(inp).to(torch.device('cpu'))
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, -19:-1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
assert not torch.isnan(cnn_out).any()
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
assert not torch.isnan(rnn_in).any()
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
assert not torch.isnan(rnn2_in).any()
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
assert not torch.isnan(rnn2_out).any()
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
assert not torch.isnan(rnn_out).any()
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
self.feature = self.fc(fc_in)
|
||||
return self.feature
|
||||
|
||||
|
||||
class PPO_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
assert not (torch.isnan(self.feature).any() | torch.isinf(self.feature).any()), f"{self.feature}"
|
||||
out = self.layer_out(self.feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class PPO_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
return self.value_out(self.feature).squeeze(dim=-1)
|
||||
52
examples/trade/network/qmodel.py
Normal file
52
examples/trade/network/qmodel.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class RNNQModel(nn.Module):
|
||||
def __init__(self, device="cpu", out_shape=10, **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(32, out_shape),
|
||||
)
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
inp = to_torch(obs, dtype=torch.float32, device=self.device)
|
||||
inp = inp[:, 182:]
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240]
|
||||
raw_in = torch.cat((torch.zeros_like(inp[:, : 6 * 30]), raw_in), dim=-1)
|
||||
raw_in = raw_in.reshape(-1, 30, 6).transpose(1, 2)
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2)
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 9, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0]
|
||||
rnn_out = rnn_out[torch.arange(rnn_out.size(0)), seq_len]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
out = self.fc(fc_in)
|
||||
return out, state
|
||||
69
examples/trade/network/teacher.py
Normal file
69
examples/trade/network/teacher.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class Teacher_Extractor(nn.Module):
|
||||
def __init__(self, device="cpu", feature_size=180, **kargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
hidden_size = kargs["hidden_size"]
|
||||
fc_size = kargs["fc_size"]
|
||||
self.cnn_shape = kargs["cnn_shape"]
|
||||
|
||||
self.rnn = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.rnn2 = nn.GRU(64, hidden_size, batch_first=True)
|
||||
self.dnn = nn.Sequential(nn.Linear(2, 64), nn.ReLU(),)
|
||||
self.cnn = nn.Sequential(nn.Conv1d(self.cnn_shape[1], 3, 3), nn.ReLU(),)
|
||||
self.raw_fc = nn.Sequential(nn.Linear((self.cnn_shape[0] - 2) * 3, 64), nn.ReLU(),)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(hidden_size * 2, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 32), nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
inp = to_torch(inp, dtype=torch.float32, device=self.device)
|
||||
seq_len = inp[:, -1].to(torch.long)
|
||||
batch_size = inp.shape[0]
|
||||
raw_in = inp[:, : 6 * 240].reshape(-1, 30, 6).transpose(1, 2) ## public part of state
|
||||
dnn_in = inp[:, 6 * 240 : -1].reshape(batch_size, -1, 2) ## private part of state
|
||||
cnn_out = self.cnn(raw_in).view(batch_size, 8, -1)
|
||||
rnn_in = self.raw_fc(cnn_out)
|
||||
rnn2_in = self.dnn(dnn_in)
|
||||
rnn2_out = self.rnn2(rnn2_in)[0]
|
||||
rnn_out = self.rnn(rnn_in)[0][:, -1, :]
|
||||
rnn2_out = rnn2_out[torch.arange(rnn2_out.size(0)), seq_len]
|
||||
# dnn_out = self.dnn(dnn_in)
|
||||
fc_in = torch.cat((rnn_out, rnn2_out), dim=-1)
|
||||
self.feature = self.fc(fc_in)
|
||||
return self.feature
|
||||
|
||||
|
||||
class Teacher_Actor(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(32, out_shape), nn.Softmax(dim=-1))
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
out = self.layer_out(self.feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class Teacher_Critic(nn.Module):
|
||||
def __init__(self, extractor, out_shape, device=torch.device("cpu"), **kargs):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(32, 1)
|
||||
self.device = device
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
self.feature = self.extractor(obs)
|
||||
return self.value_out(self.feature).squeeze(-1)
|
||||
191
examples/trade/network/util.py
Normal file
191
examples/trade/network/util.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key):
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1])
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze() # B * l
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1) # B * l * 1
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class MaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
# seq_len: (batch,)
|
||||
device = value.device
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1]) # (batch, 9, 64)
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1) # (batch, 9)
|
||||
mask = sequence_mask(seq_len + 1, maxlen=maxlen, device=device)
|
||||
weight[~mask] = float("-inf")
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1)
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class TFMaskAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.get_w = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.ReLU(), nn.Linear(in_dim, 1))
|
||||
|
||||
self.fc = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU(),)
|
||||
|
||||
def forward(self, value, key, seq_len, maxlen=9):
|
||||
device = value.device
|
||||
key = key.unsqueeze(dim=1)
|
||||
length = value.shape[1]
|
||||
key = key.repeat([1, length, 1])
|
||||
weight = self.get_w(torch.cat((key, value), dim=-1)).squeeze(-1)
|
||||
mask = sequence_mask(seq_len + 1, maxlen=maxlen, device=device)
|
||||
mask = mask.repeat(1, 3) # (batch, 9*3)
|
||||
weight[~mask] = float("-inf")
|
||||
weight = weight.softmax(dim=-1).unsqueeze(dim=-1)
|
||||
out = (value * weight).sum(dim=1)
|
||||
out = self.fc(out)
|
||||
return out
|
||||
|
||||
|
||||
class NNAttention(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.q_net = nn.Linear(in_dim, out_dim)
|
||||
self.k_net = nn.Linear(in_dim, out_dim)
|
||||
self.v_net = nn.Linear(in_dim, out_dim)
|
||||
|
||||
def forward(self, Q, K, V):
|
||||
q = self.q_net(Q)
|
||||
k = self.k_net(K)
|
||||
v = self.v_net(V)
|
||||
|
||||
attn = torch.einsum("ijk,ilk->ijl", q, k)
|
||||
attn = attn.to(Q.device)
|
||||
attn_prob = torch.softmax(attn, dim=-1)
|
||||
|
||||
attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v)
|
||||
|
||||
return attn_vec
|
||||
|
||||
|
||||
class Reshape(nn.Module):
|
||||
def __init__(self, *args):
|
||||
super(Reshape, self).__init__()
|
||||
self.shape = args
|
||||
|
||||
def forward(self, x):
|
||||
return x.view(self.shape)
|
||||
|
||||
|
||||
class DARNN(nn.Module):
|
||||
def __init__(self, device="cpu", **kargs):
|
||||
super().__init__()
|
||||
self.emb_dim = kargs["emb_dim"]
|
||||
self.hidden_size = kargs["hidden_size"]
|
||||
self.num_layers = kargs["num_layers"]
|
||||
self.is_bidir = kargs["is_bidir"]
|
||||
self.dropout = kargs["dropout"]
|
||||
self.seq_len = kargs["seq_len"]
|
||||
self.interval = kargs["interval"]
|
||||
self.today_length = 238
|
||||
self.prev_length = 240
|
||||
self.input_length = 480
|
||||
self.input_size = 6
|
||||
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=self.input_size + self.emb_dim,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=self.is_bidir,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.prev_rnn = nn.LSTM(
|
||||
input_size=self.input_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
batch_first=True,
|
||||
bidirectional=self.is_bidir,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(in_features=self.hidden_size * 2, out_features=1)
|
||||
self.attention = NNAttention(self.hidden_size, self.hidden_size)
|
||||
self.act_out = nn.Sigmoid()
|
||||
if self.emb_dim != 0:
|
||||
self.pos_emb = nn.Embedding(self.input_length, self.emb_dim)
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = inputs.view(-1, self.input_length, self.input_size) # [B, T, F]
|
||||
today_input = inputs[:, : self.today_length, :]
|
||||
today_input = torch.cat((torch.zeros_like(today_input[:, :1, :]), today_input), dim=1)
|
||||
prev_input = inputs[:, 240 : 240 + self.prev_length, :]
|
||||
if self.emb_dim != 0:
|
||||
embedding = self.pos_emb(torch.arange(end=self.today_length + 1, device=inputs.device))
|
||||
embedding = embedding.repeat([today_input.size()[0], 1, 1])
|
||||
today_input = torch.cat((today_input, embedding), dim=-1)
|
||||
prev_outs, _ = self.prev_rnn(prev_input)
|
||||
today_outs, _ = self.rnn(today_input)
|
||||
|
||||
outs = self.attention(today_outs, prev_outs, prev_outs)
|
||||
outs = torch.cat((today_outs, outs), dim=-1)
|
||||
outs = outs[:, range(0, self.seq_len * self.interval, self.interval), :]
|
||||
# outs = self.fc_out(outs).squeeze()
|
||||
return self.act_out(self.fc_out(outs).squeeze(-1)), outs
|
||||
|
||||
|
||||
class Transpose(nn.Module):
|
||||
def __init__(self, dim1=0, dim2=1):
|
||||
super().__init__()
|
||||
self.dim1 = dim1
|
||||
self.dim2 = dim2
|
||||
|
||||
def forward(self, x):
|
||||
return x.transpose(self.dim1, self.dim2)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, *args, **kargs):
|
||||
super().__init__()
|
||||
self.attention = nn.MultiheadAttention(*args, **kargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.attention(x, x, x)[0]
|
||||
|
||||
|
||||
def onehot_enc(y, len):
|
||||
y = y.unsqueeze(-1)
|
||||
y_onehot = torch.zeros(y.shape[0], len)
|
||||
# y_onehot.zero_()
|
||||
y_onehot.scatter(1, y, 1)
|
||||
return y_onehot
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.bool, device=None):
|
||||
if maxlen is None:
|
||||
maxlen = lengths.max()
|
||||
mask = ~(torch.ones((len(lengths), maxlen), device=device).cumsum(dim=1).t() > lengths).t()
|
||||
mask.type(dtype)
|
||||
return mask
|
||||
3
examples/trade/observation/__init__.py
Normal file
3
examples/trade/observation/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ppo_obs import *
|
||||
from .teacher_obs import *
|
||||
from .obs_rule import *
|
||||
136
examples/trade/observation/obs_rule.py
Normal file
136
examples/trade/observation/obs_rule.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
|
||||
class BaseObs(object):
|
||||
""" """
|
||||
|
||||
def __init__(self, config):
|
||||
self._observation_space = None
|
||||
|
||||
def get_space(self):
|
||||
""" """
|
||||
return self._observation_space
|
||||
|
||||
def get_obs(self, t):
|
||||
pass
|
||||
|
||||
|
||||
class RuleObs(BaseObs):
|
||||
"""The observation for minute-level rule-based agents, which consists of prediction, private state and direction information."""
|
||||
|
||||
def __init__(self, config):
|
||||
feature_size = 0
|
||||
self.features = config["features"]
|
||||
self.time_interval = config["time_interval"]
|
||||
self.max_step_num = config["max_step_num"]
|
||||
for feature in self.features:
|
||||
feature_size += feature["size"]
|
||||
|
||||
self._observation_space = Tuple(
|
||||
(
|
||||
Box(-np.inf, np.inf, shape=(feature_size,), dtype=np.float32),
|
||||
Box(-np.inf, np.inf, shape=(4,), dtype=np.float32),
|
||||
Discrete(2),
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_obs(*args, **kargs)
|
||||
|
||||
def get_feature_res(self, df_list, time, interval, whole_day=False, interval_num=8):
|
||||
"""
|
||||
This method would extract the needed feature from the feature dataframe based on the feature name
|
||||
and the description in feature config.
|
||||
|
||||
:param df_list: The dataframes of features, the order is consistent with the feature list.
|
||||
:param time: The index of current minute of the day (starting from -1).
|
||||
:param interval: The index of interval or decition making.
|
||||
:param whole_day: if True, this method would return the concatenate of all dataframe.(Default value = False)
|
||||
|
||||
"""
|
||||
predictions = []
|
||||
if whole_day:
|
||||
try:
|
||||
prediction = [df_list[i].reshape(-1) for i in range(len(df_list))]
|
||||
except:
|
||||
prediction = [df_list[i].reshape(-1) for i in range(len(df_list))]
|
||||
for i, p in enumerate(prediction):
|
||||
if len(p) < interval_num:
|
||||
prediction[i] = np.concatenate((p, np.zeros(interval_num - len(p))), axis=-1)
|
||||
# res = np.stack(prediction).transpose().reshape(-1)
|
||||
return np.concatenate(prediction)
|
||||
for i in range(len(self.features)):
|
||||
feature = self.features[i]
|
||||
df = df_list[i]
|
||||
size = feature["size"]
|
||||
if feature["type"] == "inday":
|
||||
if time == -1:
|
||||
predictions += [0.0] * size
|
||||
else:
|
||||
predictions += df[size * time : size * (time + 1)].reshape(-1).tolist()
|
||||
elif feature["type"] == "daily":
|
||||
predictions += df.reshape(-1)[:size].tolist()
|
||||
elif feature["type"] == "range":
|
||||
if time == -1:
|
||||
predictions += [0.0] * size
|
||||
else:
|
||||
predictions += df[time : size + time].reshape(-1).tolist()
|
||||
elif feature["type"] == "interval":
|
||||
if len(df[interval * size : (interval + 1) * size].reshape(-1)) == size:
|
||||
predictions += df[interval * size : (interval + 1) * size].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
elif feature["type"] == "step":
|
||||
if len(df[size * (time + 1) : size * (time + 2)].reshape(-1)) == size:
|
||||
predictions += df[size * (time + 1) : size * (time + 2)].reshape(-1).tolist()
|
||||
else:
|
||||
predictions += [0.0] * size
|
||||
|
||||
return np.array(predictions)
|
||||
|
||||
def get_obs(self, raw_df, feature_dfs, t, interval, position, target, is_buy, *args, **kargs):
|
||||
private_state = np.array([position, target, t, self.max_step_num])
|
||||
prediction_state = self.get_feature_res(feature_dfs, t, interval)
|
||||
return {
|
||||
"prediction": prediction_state,
|
||||
"private": private_state,
|
||||
"is_buy": int(is_buy),
|
||||
}
|
||||
|
||||
|
||||
class RuleInterval(RuleObs):
|
||||
"""
|
||||
The observation for interval_level rule based strategy.
|
||||
|
||||
Consist of interval prediction, private state, direction
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def get_obs(
|
||||
self,
|
||||
raw_df,
|
||||
feature_dfs,
|
||||
t,
|
||||
interval,
|
||||
position,
|
||||
target,
|
||||
is_buy,
|
||||
max_step_num,
|
||||
interval_num,
|
||||
action=1.0,
|
||||
*args,
|
||||
**kargs
|
||||
):
|
||||
private_state = np.array([position, target, interval - 1, interval_num])
|
||||
prediction_state = self.get_feature_res(feature_dfs, t, interval)
|
||||
return {
|
||||
"prediction": prediction_state,
|
||||
"private": private_state,
|
||||
"is_buy": int(is_buy),
|
||||
"action": action,
|
||||
}
|
||||
28
examples/trade/observation/ppo_obs.py
Normal file
28
examples/trade/observation/ppo_obs.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
from .obs_rule import RuleObs
|
||||
|
||||
|
||||
class PPOObs(RuleObs):
|
||||
"""The observation defined in IJCAI 2020. The action of previous state is included in private state"""
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, action=0,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
|
||||
public_state = self.get_feature_res(feature_dfs, t, interval, whole_day=True)
|
||||
# market_state = feature_dfs[0].reshape(-1)[:6*240]
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num, action])
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 3 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
55
examples/trade/observation/teacher_obs.py
Normal file
55
examples/trade/observation/teacher_obs.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from gym.spaces import Discrete, Box, Tuple, MultiDiscrete
|
||||
import math
|
||||
import json
|
||||
|
||||
from .obs_rule import RuleObs
|
||||
|
||||
|
||||
class TeacherObs(RuleObs):
|
||||
"""
|
||||
The Observation used for OPD method.
|
||||
|
||||
Consist of public state(raw feature), private state, seqlen
|
||||
|
||||
"""
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
public_state = self.get_feature_res(feature_dfs, t, interval, whole_day=True)
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num])
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
assert not (
|
||||
np.isnan(list_private_state).any() | np.isinf(list_private_state).any()
|
||||
), f"{private_state}, {target}"
|
||||
assert not (np.isnan(public_state).any() | np.isinf(public_state).any()), f"{public_state}"
|
||||
return np.concatenate((public_state, list_private_state, seqlen))
|
||||
|
||||
|
||||
class RuleTeacher(RuleObs):
|
||||
""" """
|
||||
|
||||
def get_obs(
|
||||
self, raw_df, feature_dfs, t, interval, position, target, is_buy, max_step_num, interval_num, *args, **kargs,
|
||||
):
|
||||
if t == -1:
|
||||
self.private_states = []
|
||||
public_state = feature_dfs[0].reshape(-1)[: 6 * 240]
|
||||
private_state = np.array([position / target, (t + 1) / max_step_num])
|
||||
teacher_action = self.get_feature_res(feature_dfs, t, interval)[-self.features[1]["size"] :]
|
||||
self.private_states.append(private_state)
|
||||
list_private_state = np.concatenate(self.private_states)
|
||||
list_private_state = np.concatenate(
|
||||
(list_private_state, [0.0] * 2 * (interval_num + 1 - len(self.private_states)),)
|
||||
)
|
||||
seqlen = np.array([interval])
|
||||
return np.concatenate((teacher_action, public_state, list_private_state, seqlen))
|
||||
62
examples/trade/order_gen.py
Normal file
62
examples/trade/order_gen.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import time
|
||||
import datetime
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
data_path = '../data/'
|
||||
in_dir = os.path.join(data_path, 'backtest/')
|
||||
|
||||
### create order folders ####
|
||||
|
||||
def generate_order(df, start, end):
|
||||
# df['date'] = df.index.map(lambda x: x[1].date())
|
||||
# df.set_index('date', append=True, inplace=True)
|
||||
df = df.groupby('date').take(range(start, end)).droplevel(level=0)
|
||||
div = df['$volume0'].rolling((end - start)*60).mean().shift(1).groupby(level='date').transform('first')
|
||||
order = df.groupby(level=(2, 0)).mean().dropna()
|
||||
order = pd.DataFrame(order)
|
||||
order['amount'] = np.random.lognormal(-3.28, 1.14) * order['$volume0']
|
||||
order['order_type'] = 0
|
||||
order = order.drop(columns=["$volume0", "$vwap0"])
|
||||
return order
|
||||
|
||||
def w_order(f, start, end):
|
||||
df = pd.read_pickle(in_dir + f)
|
||||
#df['date'] = df.index.get_level_values(1).map(lambda x: x.date())
|
||||
#df = df.set_index('date', append=True, drop=True)
|
||||
|
||||
order = generate_order(df, start, end)
|
||||
order_train = order[order.index.get_level_values(0) < '2020-12-01']
|
||||
order_test = order[order.index.get_level_values(0) >= '2020-12-01']
|
||||
order_valid = order_test[order_test.index.get_level_values(0) < '2021-01-01']
|
||||
order_test = order_test[order_test.index.get_level_values(0) >= '2021-01-01']
|
||||
if len(order_train) > 0:
|
||||
order_train.to_pickle(train_path + f[:-9] + '.target')
|
||||
if len(order_valid) > 0:
|
||||
order_valid.to_pickle(valid_path + f[:-9] + '.target')
|
||||
if len(order_test) > 0:
|
||||
order_test.to_pickle(test_path + f[:-9] + '.target')
|
||||
if len(order) > 0:
|
||||
order.to_pickle(all_path + f[:-9] + '.target')
|
||||
return 0
|
||||
|
||||
train_path = os.path.join(data_path, "order/train/")
|
||||
if not os.path.exists(train_path):
|
||||
os.makedirs(train_path)
|
||||
|
||||
valid_path = os.path.join(data_path, "order/valid/")
|
||||
if not os.path.exists(valid_path):
|
||||
os.makedirs(valid_path)
|
||||
|
||||
test_path = os.path.join(data_path, "order/test/")
|
||||
if not os.path.exists(test_path):
|
||||
os.makedirs(test_path)
|
||||
|
||||
all_path = os.path.join(data_path, "order/all/")
|
||||
if not os.path.exists(all_path):
|
||||
os.makedirs(all_path)
|
||||
|
||||
res = Parallel(n_jobs=64)(delayed(w_order)(f, 0, 239) for f in os.listdir(in_dir))
|
||||
print(sum(res))
|
||||
2
examples/trade/policy/__init__.py
Normal file
2
examples/trade/policy/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .ppo_supervision import *
|
||||
from .ppo import *
|
||||
255
examples/trade/policy/ppo.py
Normal file
255
examples/trade/policy/ppo.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data import to_torch
|
||||
from numba import njit
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import to_numpy, to_torch_as
|
||||
|
||||
|
||||
def _episodic_return(
|
||||
v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray, gamma: float, gae_lambda: float,
|
||||
) -> np.ndarray:
|
||||
"""Numba speedup: 4.1s -> 0.057s."""
|
||||
returns = np.roll(v_s_, 1)
|
||||
m = (1.0 - done) * gamma
|
||||
delta = rew + v_s_ * m - returns
|
||||
m *= gae_lambda
|
||||
gae = 0.0
|
||||
for i in range(len(rew) - 1, -1, -1):
|
||||
gae_new = delta[i] + m[i] * gae
|
||||
gae = gae_new
|
||||
returns[i] += gae
|
||||
return returns
|
||||
|
||||
|
||||
class PPO(PGPolicy):
|
||||
""" The PPO policy with Teacher supervision"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
teacher=None,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_clip_para=10.0,
|
||||
vf_coef: float = 0.5,
|
||||
kl_coef=0.5,
|
||||
kl_target=0.01,
|
||||
ent_coef: float = 0.01,
|
||||
sup_coef=0.1,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
self._vf_clip_para = vf_clip_para
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
self._range = action_range
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.optim = optim
|
||||
self.sup_coef = sup_coef
|
||||
self.kl_target = kl_target
|
||||
self.kl_coef = kl_coef
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
if not teacher is None:
|
||||
self.teacher = torch.load(teacher, map_location=torch.device("cpu"))
|
||||
self.teacher.to(self.actor.device)
|
||||
self.teacher.actor.extractor.device = self.actor.device
|
||||
else:
|
||||
self.teacher = None
|
||||
|
||||
@staticmethod
|
||||
def compute_episodic_return(
|
||||
batch: Batch,
|
||||
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
rew_norm: bool = False,
|
||||
) -> Batch:
|
||||
"""Compute returns over given full-length episodes.
|
||||
Implementation of Generalized Advantage Estimator (arXiv:1506.02438).
|
||||
:param batch: a data batch which contains several full-episode data
|
||||
chronologically.
|
||||
:type batch: :class:`~tianshou.data.Batch`
|
||||
:param v_s_: the value function of all next states :math:`V(s')`.
|
||||
:type v_s_: numpy.ndarray
|
||||
:param float gamma: the discount factor, should be in [0, 1], defaults
|
||||
to 0.99.
|
||||
:param float gae_lambda: the parameter for Generalized Advantage
|
||||
Estimation, should be in [0, 1], defaults to 0.95.
|
||||
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
|
||||
to False.
|
||||
:return: a Batch. The result will be stored in batch.returns as a numpy
|
||||
array with shape (bsz, ).
|
||||
"""
|
||||
rew = batch.rew
|
||||
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten())
|
||||
assert not np.isnan(v_s_).any()
|
||||
assert not np.isnan(rew).any()
|
||||
assert not np.isnan(batch.done).any()
|
||||
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
|
||||
assert not np.isnan(returns).any()
|
||||
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
|
||||
returns = (returns - returns.mean()) / returns.std()
|
||||
assert not np.isnan(returns).any()
|
||||
batch.returns = returns
|
||||
return batch
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
assert not np.isnan(batch.rew).any()
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
assert not np.isnan(v_).any()
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
"""Compute action over the given batch data."""
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
if self.training:
|
||||
try:
|
||||
act = dist.sample()
|
||||
except:
|
||||
print(logits)
|
||||
act = dist.sample()
|
||||
else:
|
||||
act = torch.argmax(logits, dim=1)
|
||||
if self._range:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses = [], [], [], [], []
|
||||
if self.teacher is not None:
|
||||
supervision_losses = []
|
||||
v = []
|
||||
old_log_prob = []
|
||||
feature = []
|
||||
old_logits = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
v.append(self.critic(b.obs))
|
||||
b_ = self(b)
|
||||
dist = b_.dist
|
||||
logits = b_.logits
|
||||
old_log_prob.append(dist.log_prob(to_torch_as(b.act, v[0])))
|
||||
old_logits.append(logits)
|
||||
if not self.teacher is None:
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
self.teacher(b)
|
||||
feature.append(self.teacher.actor.feature)
|
||||
batch.old_feature = torch.cat(feature, dim=0)
|
||||
batch.old_logits = torch.cat(old_logits, dim=0)
|
||||
batch.v = torch.cat(v, dim=0) # old value
|
||||
batch.act = to_torch_as(batch.act, v[0])
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = to_torch_as(batch.returns, v[0]).reshape(batch.v.shape)
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.returns = (batch.returns - mean) / std
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
dist = self(b).dist
|
||||
value = self.critic(b.obs)
|
||||
if not self.teacher is None:
|
||||
feature = self.actor.feature
|
||||
# print(feature.pow(2).mean())
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = (b.returns - value).pow(2).mean()
|
||||
if not self.teacher is None:
|
||||
supervision_loss = (b.old_feature - feature).pow(2).mean()
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
if self.teacher is not None:
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
cur_kl = np.mean(kl_losses)
|
||||
if cur_kl > 2.0 * self.kl_target:
|
||||
self.kl_coef *= 1.5
|
||||
elif cur_kl < 0.5 * self.kl_target:
|
||||
self.kl_coef *= 0.5
|
||||
res = {
|
||||
"loss/total_loss": losses,
|
||||
"loss/policy": clip_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/entropy": ent_losses,
|
||||
"loss/kl": kl_losses,
|
||||
}
|
||||
if not self.teacher is None:
|
||||
res["loss/supervision"] = supervision_losses
|
||||
return res
|
||||
|
||||
|
||||
Student_new = PPO
|
||||
187
examples/trade/policy/ppo_supervision.py
Normal file
187
examples/trade/policy/ppo_supervision.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data import to_torch
|
||||
from numba import njit
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
from util import to_numpy, to_torch_as
|
||||
|
||||
from .ppo import _episodic_return
|
||||
|
||||
|
||||
class PPO_sup(PGPolicy):
|
||||
"""The PPO policy with a log-likelihood supervision loss"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_clip_para=10.0,
|
||||
vf_coef: float = 0.5,
|
||||
kl_coef=0.5,
|
||||
kl_target=0.01,
|
||||
ent_coef: float = 0.01,
|
||||
sup_coef=0.1,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
self._vf_clip_para = vf_clip_para
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
self._range = action_range
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.optim = optim
|
||||
self.sup_coef = sup_coef
|
||||
self.kl_target = kl_target
|
||||
self.kl_coef = kl_coef
|
||||
self._batch = 64
|
||||
assert 0 <= gae_lambda <= 1, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert dual_clip is None or dual_clip > 1, "Dual-clip PPO parameter should greater than 1."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
if self._lambda in [0, 1]:
|
||||
return self.compute_episodic_return(batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
else:
|
||||
v_ = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
return self.compute_episodic_return(batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
|
||||
def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch:
|
||||
logits, h = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
if self.training:
|
||||
act = dist.sample()
|
||||
else:
|
||||
act = torch.argmax(logits, dim=1)
|
||||
if self._range:
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]:
|
||||
self._batch = batch_size
|
||||
losses, clip_losses, vf_losses, ent_losses, kl_losses, supervision_losses = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
v = []
|
||||
old_log_prob = []
|
||||
teacher_action = []
|
||||
old_logits = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(batch_size, shuffle=False):
|
||||
v.append(self.critic(b.obs))
|
||||
b_ = self(b)
|
||||
dist = b_.dist
|
||||
logits = b_.logits
|
||||
old_log_prob.append(dist.log_prob(to_torch_as(b.act, v[0])))
|
||||
old_logits.append(logits)
|
||||
teacher_action.append(self.actor.teacher_action)
|
||||
|
||||
batch.teacher_action = torch.cat(teacher_action, dim=0).to(torch.long)
|
||||
batch.old_logits = torch.cat(old_logits, dim=0)
|
||||
batch.v = torch.cat(v, dim=0) # old value
|
||||
batch.act = to_torch_as(batch.act, v[0])
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
batch.returns = to_torch_as(batch.returns, v[0]).reshape(batch.v.shape)
|
||||
if self._rew_norm:
|
||||
mean, std = batch.returns.mean(), batch.returns.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.returns = (batch.returns - mean) / std
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if not np.isclose(std.item(), 0):
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size):
|
||||
res = self(b)
|
||||
logits = res.logits
|
||||
dist = res.dist
|
||||
value = self.critic(b.obs)
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean()
|
||||
else:
|
||||
clip_loss = -torch.min(surr1, surr2).mean()
|
||||
clip_losses.append(clip_loss.item())
|
||||
if self._value_clip:
|
||||
v_clip = b.v + (value - b.v).clamp(-self._vf_clip_para, self._vf_clip_para)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = (b.returns - value).pow(2).mean()
|
||||
supervision_loss = F.nll_loss(logits.log(), b.teacher_action)
|
||||
supervision_losses.append(supervision_loss.item())
|
||||
kl = torch.distributions.kl.kl_divergence(self.dist_fn(b.old_logits), dist)
|
||||
kl_loss = kl.mean()
|
||||
kl_losses.append(kl_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + self.kl_coef * kl_loss
|
||||
loss += self.sup_coef * supervision_loss
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
if hasattr(self.actor, "callback"):
|
||||
self.actor.callback()
|
||||
cur_kl = np.mean(kl_losses)
|
||||
if cur_kl > 2.0 * self.kl_target:
|
||||
self.kl_coef *= 1.5
|
||||
elif cur_kl < 0.5 * self.kl_target:
|
||||
self.kl_coef *= 0.5
|
||||
res = {
|
||||
"loss/total_loss": losses,
|
||||
"loss/policy": clip_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/entropy": ent_losses,
|
||||
"loss/kl": kl_losses,
|
||||
"loss/supervision": supervision_losses,
|
||||
}
|
||||
return res
|
||||
10
examples/trade/requirements.txt
Normal file
10
examples/trade/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
gym==0.17.3
|
||||
torch==1.6.0
|
||||
numba==0.51.2
|
||||
numpy==1.19.1
|
||||
pandas==1.1.3
|
||||
tqdm==4.50.2
|
||||
tianshou==0.3.0.post1
|
||||
env==0.1.0
|
||||
PyYAML==5.4.1
|
||||
redis==3.5.3
|
||||
4
examples/trade/reward/__init__.py
Normal file
4
examples/trade/reward/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .base import *
|
||||
from .pa_penalty import *
|
||||
from .ppo_reward import *
|
||||
from .vp_penalty import *
|
||||
38
examples/trade/reward/base.py
Normal file
38
examples/trade/reward/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Abs_Reward(object):
|
||||
"""The abstract class for Reward."""
|
||||
|
||||
def __init__(self, config):
|
||||
return
|
||||
|
||||
def get_reward(self):
|
||||
""":return: reward"""
|
||||
reward = 0
|
||||
return reward
|
||||
|
||||
def __call__(self, *args, **kargs):
|
||||
return self.get_reward(*args, **kargs)
|
||||
|
||||
def isinstant(self):
|
||||
""":return: Whether the reward should be given at every timestep or only at the end of this episode."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Instant_Reward(Abs_Reward):
|
||||
def __init__(self, config):
|
||||
self.ffr_ratio = config["ffr_ratio"]
|
||||
self.vvr_ratio = config["vvr_ratio"]
|
||||
|
||||
def isinstant(self):
|
||||
return True
|
||||
|
||||
|
||||
class EndEpisode_Reward(Abs_Reward):
|
||||
def __init__(self, config):
|
||||
self.ffr_ratio = config["ffr_ratio"]
|
||||
self.vvr_ratio = config["vvr_ratio"]
|
||||
|
||||
def isinstant(self):
|
||||
return False
|
||||
14
examples/trade/reward/pa_penalty.py
Normal file
14
examples/trade/reward/pa_penalty.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import numpy as np
|
||||
from .base import Instant_Reward
|
||||
|
||||
|
||||
class PA_Penalty(Instant_Reward):
|
||||
"""Reward: (Abs(tt_ratio_t - 1) * 10000 * v_t / target - v_t^2 * penalty) / 100"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.penalty = config["penalty"]
|
||||
|
||||
def get_reward(self, performance_raise, v_t, target, PA_t, *args):
|
||||
reward = PA_t * v_t / target
|
||||
reward -= self.penalty * (v_t / target) ** 2
|
||||
return reward / 100
|
||||
22
examples/trade/reward/ppo_reward.py
Normal file
22
examples/trade/reward/ppo_reward.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
from .base import Abs_Reward
|
||||
|
||||
|
||||
class PPO_Reward(Abs_Reward):
|
||||
"""The reward function defined in IJCAI 2020"""
|
||||
|
||||
def __init__(self, *args):
|
||||
pass
|
||||
|
||||
def isinstant(self):
|
||||
return False
|
||||
|
||||
def get_reward(self, performace_raise, ffr, this_tt_ratio, is_buy):
|
||||
if is_buy:
|
||||
this_tt_ratio = 1 / this_tt_ratio
|
||||
if this_tt_ratio < 1:
|
||||
return -1.0
|
||||
elif this_tt_ratio < 1.1:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0
|
||||
37
examples/trade/reward/vp_penalty.py
Normal file
37
examples/trade/reward/vp_penalty.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
from .base import Instant_Reward
|
||||
|
||||
|
||||
class VP_Penalty_small(Instant_Reward):
|
||||
"""Reward: (Abs(vv_ratio_t - 1) * 10000 - v_t^2 * penalty) / 100"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.penalty = config["penalty"]
|
||||
|
||||
def get_reward(self, performance_raise, v_t, target, *args):
|
||||
"""
|
||||
|
||||
:param performance_raise: Abs(vv_ratio_t - 1) * 10000.
|
||||
:param target: Target volume
|
||||
:param v_t: The traded volume
|
||||
"""
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t / target
|
||||
reward -= self.penalty * (v_t / target) ** 2
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
|
||||
|
||||
class VP_Penalty_small_vec(VP_Penalty_small):
|
||||
def get_reward(self, performance_raise, v_t, target, *args):
|
||||
"""
|
||||
|
||||
:param performance_raise: Abs(vv_ratio_t - 1) * 10000.
|
||||
:param target: Target volume
|
||||
:param v_t: The traded volume
|
||||
"""
|
||||
assert target > 0
|
||||
reward = performance_raise * v_t.sum() / target
|
||||
reward -= self.penalty * ((v_t / target) ** 2).sum()
|
||||
assert not (np.isnan(reward) or np.isinf(reward)), f"{performance_raise}, {v_t}, {target}"
|
||||
return reward / 100
|
||||
1
examples/trade/sampler/__init__.py
Normal file
1
examples/trade/sampler/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .single_sampler import *
|
||||
184
examples/trade/sampler/single_sampler.py
Normal file
184
examples/trade/sampler/single_sampler.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from multiprocessing.context import Process
|
||||
from multiprocessing import Queue
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append("..")
|
||||
|
||||
|
||||
def toArray(data):
|
||||
if type(data) == np.ndarray:
|
||||
return data
|
||||
|
||||
elif type(data) == list:
|
||||
data = np.array(data)
|
||||
return data
|
||||
|
||||
elif type(data) == pd.DataFrame:
|
||||
share_index = toArray(data.index)
|
||||
share_value = toArray(data.values)
|
||||
share_colmns = toArray(data.columns)
|
||||
return share_index, share_value, share_colmns
|
||||
|
||||
else:
|
||||
try:
|
||||
share_array = np.array(data)
|
||||
return share_array
|
||||
except:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Sampler:
|
||||
"""The sampler for training of single-assert RL."""
|
||||
|
||||
def __init__(self, config):
|
||||
self.raw_dir = config["raw_dir"] + "/"
|
||||
self.order_dir = config["order_dir"] + "/"
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
self.features = config["features"]
|
||||
self.queue = Queue(1000)
|
||||
self.child = None
|
||||
self.ins = None
|
||||
self.raw_df = None
|
||||
self.df_list = None
|
||||
self.order_df = None
|
||||
|
||||
@staticmethod
|
||||
def _worker(order_dir, raw_dir, features, ins_list, queue):
|
||||
ins = None
|
||||
index = 0
|
||||
date_list = []
|
||||
while True:
|
||||
if ins is None or index == len(date_list):
|
||||
ins = np.random.choice(ins_list, 1)[0]
|
||||
# print(ins)
|
||||
order_df = pd.read_pickle(order_dir + ins + ".pkl.target")
|
||||
feature_df_list = []
|
||||
for feature in features:
|
||||
feature_df_list.append(pd.read_pickle(f"{feature['loc']}/{ins}.pkl"))
|
||||
raw_df = pd.read_pickle(raw_dir + ins + ".pkl.backtest")
|
||||
date_list = order_df.index.get_level_values(0).tolist()
|
||||
index = 0
|
||||
date = date_list[index]
|
||||
day_order_df = order_df.iloc[index]
|
||||
target = day_order_df["amount"]
|
||||
index += 1
|
||||
if target == 0:
|
||||
continue
|
||||
day_feature_dfs = []
|
||||
day_raw_df = raw_df.loc[pd.IndexSlice[ins, :, date]]
|
||||
is_buy = bool(day_order_df["order_type"])
|
||||
for df in feature_df_list:
|
||||
day_feature_dfs.append(df.loc[ins, date].values)
|
||||
day_feature_dfs = np.array(day_feature_dfs)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(ins, date, day_raw_df_value, day_raw_df_column, day_raw_df_index, day_feature_dfs_, target, is_buy,),
|
||||
block=True,
|
||||
)
|
||||
|
||||
def _sample_ins(self):
|
||||
""" """
|
||||
return np.random.choice(self.ins_list, 1)[0]
|
||||
|
||||
def reset(self):
|
||||
""" """
|
||||
if self.child is None:
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
|
||||
def sample(self):
|
||||
""" """
|
||||
sample = self.queue.get(block=True)
|
||||
return sample
|
||||
|
||||
def stop(self):
|
||||
""" """
|
||||
try:
|
||||
self.child.terminate()
|
||||
except:
|
||||
for p in self.child:
|
||||
p.terminate()
|
||||
|
||||
|
||||
class TestSampler(Sampler):
|
||||
"""The sampler for backtest of single-assert strategies."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.ins_index = -1
|
||||
|
||||
def _sample_ins(self):
|
||||
""" """
|
||||
self.ins_index += 1
|
||||
if self.ins_index >= len(self.ins_list):
|
||||
return None
|
||||
else:
|
||||
return self.ins_list[self.ins_index]
|
||||
|
||||
@staticmethod
|
||||
def _worker(order_dir, raw_dir, features, ins_list, queue):
|
||||
for ins in ins_list:
|
||||
order_df = pd.read_pickle(order_dir + ins + ".pkl.target")
|
||||
df_list = []
|
||||
for feature in features:
|
||||
df_list.append(pd.read_pickle(f"{feature['loc']}/{ins}.pkl"))
|
||||
raw_df = pd.read_pickle(raw_dir + ins + ".pkl.backtest")
|
||||
date_list = order_df.index.get_level_values(0).tolist()
|
||||
for index in range(len(date_list)):
|
||||
date = date_list[index]
|
||||
day_df_list = []
|
||||
day_raw_df = raw_df.loc[pd.IndexSlice[ins, :, date]]
|
||||
day_order_df = order_df.iloc[index]
|
||||
target = day_order_df["amount"]
|
||||
if target == 0:
|
||||
continue
|
||||
is_buy = bool(day_order_df["order_type"])
|
||||
for df in df_list:
|
||||
day_df_list.append(df.loc[ins, date].values)
|
||||
day_feature_dfs = np.array(day_df_list)
|
||||
day_raw_df_index, day_raw_df_value, day_raw_df_column = toArray(day_raw_df)
|
||||
day_feature_dfs_ = toArray(day_feature_dfs)
|
||||
queue.put(
|
||||
(
|
||||
ins,
|
||||
date,
|
||||
day_raw_df_value,
|
||||
day_raw_df_column,
|
||||
day_raw_df_index,
|
||||
day_feature_dfs_,
|
||||
target,
|
||||
is_buy,
|
||||
),
|
||||
block=True,
|
||||
)
|
||||
for _ in range(100):
|
||||
queue.put(None)
|
||||
|
||||
def reset(self, order_dir=None):
|
||||
"""
|
||||
|
||||
reset the sampler and change self.order_dir if order_dir is not None.
|
||||
|
||||
"""
|
||||
if order_dir:
|
||||
self.order_dir = order_dir
|
||||
self.ins_list = [f[:-11] for f in os.listdir(self.order_dir) if f.endswith("target")]
|
||||
if not self.child is None:
|
||||
self.child.terminate()
|
||||
while not self.queue.empty():
|
||||
self.queue.get()
|
||||
self.child = Process(
|
||||
target=self._worker,
|
||||
args=(self.order_dir, self.raw_dir, self.features, self.ins_list, self.queue,),
|
||||
daemon=True,
|
||||
)
|
||||
self.child.start()
|
||||
28
examples/trade/teacher_feature.py
Normal file
28
examples/trade/teacher_feature.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pandas as pd
|
||||
import os
|
||||
|
||||
data_path = '../data/'
|
||||
feature_path = os.path.join(data_path, 'feature/teacher/')
|
||||
if not os.path.exists(feature_path):
|
||||
os.makedirs(feature_path)
|
||||
|
||||
|
||||
log_file = os.path.join(os.environ.get('OUTPUT_DIR'),'example/OPDT_b/test/')
|
||||
|
||||
files = os.listdir(log_file)
|
||||
|
||||
for f in files:
|
||||
if f.endswith(".log"):
|
||||
df = pd.read_pickle(log_file + f)
|
||||
|
||||
#df['datetime'] = df.index.get_level_values(1).map(lambda x: x[1])
|
||||
df['datetime'] = df.index.get_level_values(1)
|
||||
df.set_index('datetime', append=True, drop=True, inplace=True)
|
||||
action = df['action']
|
||||
action = action.reset_index(level=1, drop=True)
|
||||
action.index = action.index.map(lambda x: (x[0], x[1], x[2].time()))
|
||||
action = action.unstack().iloc[:, ::30] * 2
|
||||
action = action.fillna(0)
|
||||
train_action = action.astype("int")
|
||||
final = train_action
|
||||
final.to_pickle(feature_path + f[:-4] + '.pkl')
|
||||
303
examples/trade/util.py
Normal file
303
examples/trade/util.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from collections import namedtuple
|
||||
from torch.nn.utils.rnn import pack_padded_sequence
|
||||
from tianshou.data import Batch
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
from typing import Union, Optional
|
||||
from numbers import Number
|
||||
|
||||
|
||||
def nan_weighted_avg(vals, weights, axis=None):
|
||||
"""
|
||||
|
||||
:param vals: The values to be averaged on.
|
||||
:param weights: The weights of weighted avrage.
|
||||
:param axis: On which axis to calculate the weighted avrage. (Default value = None)
|
||||
|
||||
"""
|
||||
assert vals.shape == weights.shape, AssertionError(f"{vals.shape} & {weights.shape}")
|
||||
vals = vals.copy()
|
||||
weights = weights.copy()
|
||||
res = (vals * weights).sum(axis=axis) / weights.sum(axis=axis)
|
||||
return np.nan_to_num(res, nan=vals[0])
|
||||
|
||||
|
||||
def robust_auc(y_true, y_pred):
|
||||
"""
|
||||
|
||||
Calculate AUC.
|
||||
|
||||
"""
|
||||
try:
|
||||
return roc_auc_score(y_true, y_pred)
|
||||
except:
|
||||
return np.nan
|
||||
|
||||
|
||||
def merge_dicts(d1, d2):
|
||||
"""
|
||||
|
||||
:param d1: Dict 1.
|
||||
:type d1: dict
|
||||
:param d2: Dict 2.
|
||||
:returns: A new dict that is d1 and d2 deep merged.
|
||||
:rtype: dict
|
||||
|
||||
"""
|
||||
merged = copy.deepcopy(d1)
|
||||
deep_update(merged, d2, True, [])
|
||||
return merged
|
||||
|
||||
|
||||
def deep_update(
|
||||
original, new_dict, new_keys_allowed=False, whitelist=None, override_all_if_type_changes=None,
|
||||
):
|
||||
"""Updates original dict with values from new_dict recursively.
|
||||
If new key is introduced in new_dict, then if new_keys_allowed is not
|
||||
True, an error will be thrown. Further, for sub-dicts, if the key is
|
||||
in the whitelist, then new subkeys can be introduced.
|
||||
|
||||
:param original: Dictionary with default values.
|
||||
:type original: dict
|
||||
:param new_dict(dict: dict): Dictionary with values to be updated
|
||||
:param new_keys_allowed: Whether new keys are allowed. (Default value = False)
|
||||
:type new_keys_allowed: bool
|
||||
:param whitelist: List of keys that correspond to dict
|
||||
values where new subkeys can be introduced. This is only at the top
|
||||
level. (Default value = None)
|
||||
:type whitelist: Optional[List[str]]
|
||||
:param override_all_if_type_changes: List of top level
|
||||
keys with value=dict, for which we always simply override the
|
||||
entire value (dict), iff the "type" key in that value dict changes. (Default value = None)
|
||||
:type override_all_if_type_changes: Optional[List[str]]
|
||||
:param new_dict:
|
||||
|
||||
"""
|
||||
whitelist = whitelist or []
|
||||
override_all_if_type_changes = override_all_if_type_changes or []
|
||||
|
||||
for k, value in new_dict.items():
|
||||
if k not in original and not new_keys_allowed:
|
||||
raise Exception("Unknown config parameter `{}` ".format(k))
|
||||
|
||||
# Both orginal value and new one are dicts.
|
||||
if isinstance(original.get(k), dict) and isinstance(value, dict):
|
||||
# Check old type vs old one. If different, override entire value.
|
||||
if (
|
||||
k in override_all_if_type_changes
|
||||
and "type" in value
|
||||
and "type" in original[k]
|
||||
and value["type"] != original[k]["type"]
|
||||
):
|
||||
original[k] = value
|
||||
# Whitelisted key -> ok to add new subkeys.
|
||||
elif k in whitelist:
|
||||
deep_update(original[k], value, True)
|
||||
# Non-whitelisted key.
|
||||
else:
|
||||
deep_update(original[k], value, new_keys_allowed)
|
||||
# Original value not a dict OR new value not a dict:
|
||||
# Override entire value.
|
||||
else:
|
||||
original[k] = value
|
||||
return original
|
||||
|
||||
|
||||
def get_seqlen(done_seq):
|
||||
"""
|
||||
|
||||
:param done_seq:
|
||||
|
||||
"""
|
||||
seqlen = []
|
||||
length = 0
|
||||
for i, done in enumerate(done_seq):
|
||||
length += 1
|
||||
if done:
|
||||
seqlen.append(length)
|
||||
length = 0
|
||||
if length > 0:
|
||||
seqlen.append(length)
|
||||
return np.array(seqlen)
|
||||
|
||||
|
||||
def generate_seq(seqlen, list):
|
||||
"""
|
||||
|
||||
:param seqlen: param list:
|
||||
:param list:
|
||||
|
||||
"""
|
||||
res = []
|
||||
index = 0
|
||||
maxlen = np.max(seqlen)
|
||||
for i in seqlen:
|
||||
if isinstance(list, torch.Tensor):
|
||||
res.append(torch.cat((list[index : index + i], torch.zeros_like(list[: maxlen - i])), dim=0,))
|
||||
else:
|
||||
res.append(np.concatenate((list[index : index + i], np.zeros_like(list[: maxlen - i])), axis=0))
|
||||
index += i
|
||||
if isinstance(list, torch.Tensor):
|
||||
res = torch.stack(res, dim=0)
|
||||
else:
|
||||
res = np.stack(res, axis=0)
|
||||
return res
|
||||
|
||||
|
||||
def sequence_batch(batch):
|
||||
"""
|
||||
|
||||
:param batch:
|
||||
|
||||
"""
|
||||
seqlen = get_seqlen(batch.done)
|
||||
# print(seqlen.max())
|
||||
# print(len(seqlen))
|
||||
res = Batch()
|
||||
# print(batch.keys())
|
||||
|
||||
for v in batch.keys():
|
||||
if v not in ["policy", "info"]:
|
||||
res[v] = generate_seq(seqlen, batch[v])
|
||||
else:
|
||||
res[v] = batch[v]
|
||||
res.seqlen = seqlen
|
||||
return res
|
||||
|
||||
|
||||
def flatten_seq(seq, seqlen):
|
||||
"""
|
||||
|
||||
:param seq: param seqlen:
|
||||
:param seqlen:
|
||||
|
||||
"""
|
||||
res = []
|
||||
for i, length in enumerate(seqlen):
|
||||
res.append(seq[i][:length])
|
||||
if isinstance(seq, torch.Tensor):
|
||||
res = torch.cat(res, dim=0)
|
||||
else:
|
||||
res = np.concatenate(res, axis=0)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def flatten_batch(batch):
|
||||
"""
|
||||
|
||||
:param batch:
|
||||
|
||||
"""
|
||||
for v in batch.keys():
|
||||
if v in ["policy", "info", "seqlen"]:
|
||||
continue
|
||||
batch[v] = flatten_seq(batch[v], batch.seqlen)
|
||||
return batch
|
||||
|
||||
|
||||
def to_numpy(
|
||||
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[Batch:
|
||||
:param dict: param list:
|
||||
:param tuple: param np.ndarray:
|
||||
:param torch: Tensor]:
|
||||
:param x: Union[Batch:
|
||||
:param list:
|
||||
:param np.ndarray:
|
||||
:param torch.Tensor]:
|
||||
:param x: Union[Batch:
|
||||
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = x.detach().cpu().numpy()
|
||||
elif isinstance(x, dict):
|
||||
for k, v in x.items():
|
||||
x[k] = to_numpy(v)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_numpy()
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_numpy(_parse_value(x))
|
||||
except TypeError:
|
||||
x = [to_numpy(e) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
return x
|
||||
|
||||
|
||||
def to_torch(
|
||||
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[Batch:
|
||||
:param dict: param list:
|
||||
:param tuple: param np.ndarray:
|
||||
:param torch: Tensor]:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
:param int: param torch.device]: (Default value = 'cpu')
|
||||
:param x: Union[Batch:
|
||||
:param list:
|
||||
:param np.ndarray:
|
||||
:param torch.Tensor]:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
:param torch.device]: (Default value = 'cpu')
|
||||
:param x: Union[Batch:
|
||||
:param dtype: Optional[torch.dtype]: (Default value = None)
|
||||
:param device: Union[str:
|
||||
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
x = x.to(device)
|
||||
elif isinstance(x, dict):
|
||||
for k, v in x.items():
|
||||
x[k] = to_torch(v, dtype, device)
|
||||
elif isinstance(x, Batch):
|
||||
x.to_torch(dtype, device)
|
||||
elif isinstance(x, (np.number, np.bool_, Number)):
|
||||
x = to_torch(np.asanyarray(x), dtype, device)
|
||||
elif isinstance(x, (list, tuple)):
|
||||
try:
|
||||
x = to_torch(_parse_value(x), dtype, device)
|
||||
except TypeError:
|
||||
x = [to_torch(e, dtype, device) for e in x]
|
||||
else: # fallback
|
||||
x = np.asanyarray(x)
|
||||
if issubclass(x.dtype.type, (np.bool_, np.number)):
|
||||
x = torch.from_numpy(x).to(device)
|
||||
if dtype is not None:
|
||||
x = x.type(dtype)
|
||||
else:
|
||||
raise TypeError(f"object {x} cannot be converted to torch.")
|
||||
return x
|
||||
|
||||
|
||||
def to_torch_as(x: Union[torch.Tensor, dict, Batch, np.ndarray], y: torch.Tensor) -> Union[dict, Batch, torch.Tensor]:
|
||||
"""
|
||||
|
||||
:param x: Union[torch.Tensor:
|
||||
:param dict: param Batch:
|
||||
:param np: ndarray]:
|
||||
:param y: torch.Tensor:
|
||||
:param x: Union[torch.Tensor:
|
||||
:param Batch:
|
||||
:param np.ndarray]:
|
||||
:param y: torch.Tensor:
|
||||
:param x: Union[torch.Tensor:
|
||||
:param y: torch.Tensor:
|
||||
:returns: to_torch(x, dtype=y.dtype, device=y.device)``.
|
||||
|
||||
"""
|
||||
assert isinstance(y, torch.Tensor)
|
||||
return to_torch(x, dtype=y.dtype, device=y.device)
|
||||
695
examples/trade/vecenv.py
Normal file
695
examples/trade/vecenv.py
Normal file
@@ -0,0 +1,695 @@
|
||||
import gym
|
||||
import time
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from multiprocessing.context import Process
|
||||
from multiprocessing import Array, Pipe, connection, Queue
|
||||
from typing import Any, List, Tuple, Union, Callable, Optional
|
||||
|
||||
from tianshou.env.worker import EnvWorker
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
_NP_TO_CT = {
|
||||
np.bool: ctypes.c_bool,
|
||||
np.bool_: ctypes.c_bool,
|
||||
np.uint8: ctypes.c_uint8,
|
||||
np.uint16: ctypes.c_uint16,
|
||||
np.uint32: ctypes.c_uint32,
|
||||
np.uint64: ctypes.c_uint64,
|
||||
np.int8: ctypes.c_int8,
|
||||
np.int16: ctypes.c_int16,
|
||||
np.int32: ctypes.c_int32,
|
||||
np.int64: ctypes.c_int64,
|
||||
np.float32: ctypes.c_float,
|
||||
np.float64: ctypes.c_double,
|
||||
}
|
||||
|
||||
|
||||
class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(
|
||||
_NP_TO_CT[dtype.type], # type: ignore
|
||||
int(np.prod(shape)),
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
def save(self, ndarray: np.ndarray) -> None:
|
||||
"""
|
||||
|
||||
:param ndarray: np.ndarray:
|
||||
:param ndarray: np.ndarray:
|
||||
:param ndarray: np.ndarray:
|
||||
|
||||
"""
|
||||
assert isinstance(ndarray, np.ndarray)
|
||||
dst = self.arr.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||
np.copyto(dst_np, ndarray)
|
||||
|
||||
def get(self) -> np.ndarray:
|
||||
""" """
|
||||
obj = self.arr.get_obj()
|
||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
||||
|
||||
|
||||
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||
"""
|
||||
|
||||
:param space: gym.Space:
|
||||
:param space: gym.Space:
|
||||
:param space: gym.Space:
|
||||
|
||||
"""
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
return tuple([_setup_buf(t) for t in space.spaces])
|
||||
else:
|
||||
return ShArray(space.dtype, space.shape)
|
||||
|
||||
|
||||
def _worker(
|
||||
parent: connection.Connection,
|
||||
p: connection.Connection,
|
||||
env_fn_wrapper: CloudpickleWrapper,
|
||||
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
:param tuple: param ShArray]]: (Default value = None)
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
:param ShArray]]: (Default value = None)
|
||||
:param parent: connection.Connection:
|
||||
:param p: connection.Connection:
|
||||
:param env_fn_wrapper: CloudpickleWrapper:
|
||||
:param obs_bufs: Optional[Union[dict:
|
||||
|
||||
"""
|
||||
|
||||
def _encode_obs(obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray],) -> None:
|
||||
"""
|
||||
|
||||
:param obs: Union[dict:
|
||||
:param tuple: param np.ndarray]:
|
||||
:param buffer: Union[dict:
|
||||
:param ShArray:
|
||||
:param obs: Union[dict:
|
||||
:param np.ndarray]:
|
||||
:param buffer: Union[dict:
|
||||
:param ShArray]:
|
||||
:param obs: Union[dict:
|
||||
:param buffer: Union[dict:
|
||||
|
||||
"""
|
||||
if isinstance(obs, np.ndarray) and isinstance(buffer, ShArray):
|
||||
buffer.save(obs)
|
||||
elif isinstance(obs, tuple) and isinstance(buffer, tuple):
|
||||
for o, b in zip(obs, buffer):
|
||||
_encode_obs(o, b)
|
||||
elif isinstance(obs, dict) and isinstance(buffer, dict):
|
||||
for k in obs.keys():
|
||||
_encode_obs(obs[k], buffer[k])
|
||||
return None
|
||||
|
||||
parent.close()
|
||||
env = env_fn_wrapper.data()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
cmd, data = p.recv()
|
||||
except EOFError: # the pipe has been closed
|
||||
p.close()
|
||||
break
|
||||
if cmd == "step":
|
||||
obs, reward, done, info = env.step(data)
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send((obs, reward, done, info))
|
||||
elif cmd == "reset":
|
||||
obs = env.reset(data)
|
||||
if obs_bufs is not None:
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send(obs)
|
||||
elif cmd == "close":
|
||||
p.send(env.close())
|
||||
p.close()
|
||||
break
|
||||
elif cmd == "render":
|
||||
p.send(env.render(**data) if hasattr(env, "render") else None)
|
||||
elif cmd == "seed":
|
||||
p.send(env.seed(data) if hasattr(env, "seed") else None)
|
||||
elif cmd == "getattr":
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
elif cmd == "toggle_log":
|
||||
env.toggle_log(data)
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
except KeyboardInterrupt:
|
||||
p.close()
|
||||
|
||||
|
||||
class SubprocEnvWorker(EnvWorker):
|
||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||
|
||||
def __init__(self, env_fn: Callable[[], gym.Env], share_memory: bool = False) -> None:
|
||||
super().__init__(env_fn)
|
||||
self.parent_remote, self.child_remote = Pipe()
|
||||
self.share_memory = share_memory
|
||||
self.buffer: Optional[Union[dict, tuple, ShArray]] = None
|
||||
if self.share_memory:
|
||||
dummy = env_fn()
|
||||
obs_space = dummy.observation_space
|
||||
dummy.close()
|
||||
del dummy
|
||||
self.buffer = _setup_buf(obs_space)
|
||||
args = (
|
||||
self.parent_remote,
|
||||
self.child_remote,
|
||||
CloudpickleWrapper(env_fn),
|
||||
self.buffer,
|
||||
)
|
||||
self.process = Process(target=_worker, args=args, daemon=True)
|
||||
self.process.start()
|
||||
self.child_remote.close()
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
self.parent_remote.send(["getattr", key])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||
""" """
|
||||
|
||||
def decode_obs(buffer: Optional[Union[dict, tuple, ShArray]]) -> Union[dict, tuple, np.ndarray]:
|
||||
"""
|
||||
|
||||
:param buffer: Optional[Union[dict:
|
||||
:param tuple: param ShArray]]:
|
||||
:param buffer: Optional[Union[dict:
|
||||
:param ShArray]]:
|
||||
:param buffer: Optional[Union[dict:
|
||||
|
||||
"""
|
||||
if isinstance(buffer, ShArray):
|
||||
return buffer.get()
|
||||
elif isinstance(buffer, tuple):
|
||||
return tuple([decode_obs(b) for b in buffer])
|
||||
elif isinstance(buffer, dict):
|
||||
return {k: decode_obs(v) for k, v in buffer.items()}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return decode_obs(self.buffer)
|
||||
|
||||
def reset(self, sample) -> Any:
|
||||
"""
|
||||
|
||||
:param sample:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["reset", sample])
|
||||
# obs = self.parent_remote.recv()
|
||||
# if self.share_memory:
|
||||
# obs = self._decode_obs()
|
||||
# return obs
|
||||
|
||||
def get_reset_result(self):
|
||||
""" """
|
||||
obs = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs
|
||||
|
||||
@staticmethod
|
||||
def wait( # type: ignore
|
||||
workers: List["SubprocEnvWorker"], wait_num: int, timeout: Optional[float] = None,
|
||||
) -> List["SubprocEnvWorker"]:
|
||||
"""
|
||||
|
||||
:param # type: ignoreworkers: List["SubprocEnvWorker"]:
|
||||
:param wait_num: int:
|
||||
:param timeout: Optional[float]: (Default value = None)
|
||||
:param # type: ignoreworkers: List["SubprocEnvWorker"]:
|
||||
:param wait_num: int:
|
||||
:param timeout: Optional[float]: (Default value = None)
|
||||
|
||||
"""
|
||||
remain_conns = conns = [x.parent_remote for x in workers]
|
||||
ready_conns: List[connection.Connection] = []
|
||||
remain_time, t1 = timeout, time.time()
|
||||
while len(remain_conns) > 0 and len(ready_conns) < wait_num:
|
||||
if timeout:
|
||||
remain_time = timeout - (time.time() - t1)
|
||||
if remain_time <= 0:
|
||||
break
|
||||
# connection.wait hangs if the list is empty
|
||||
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
|
||||
ready_conns.extend(new_ready_conns) # type: ignore
|
||||
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
|
||||
return [workers[conns.index(con)] for con in ready_conns]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
"""
|
||||
|
||||
:param action: np.ndarray:
|
||||
:param action: np.ndarray:
|
||||
:param action: np.ndarray:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["step", action])
|
||||
|
||||
def toggle_log(self, log):
|
||||
self.parent_remote.send(["toggle_log", log])
|
||||
|
||||
def get_result(self,) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
""" """
|
||||
obs, rew, done, info = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return obs, rew, done, info
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
"""
|
||||
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
:param seed: Optional[int]: (Default value = None)
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["seed", seed])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def render(self, **kwargs: Any) -> Any:
|
||||
"""
|
||||
|
||||
:param **kwargs: Any:
|
||||
:param **kwargs: Any:
|
||||
|
||||
"""
|
||||
self.parent_remote.send(["render", kwargs])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def close_env(self) -> None:
|
||||
""" """
|
||||
try:
|
||||
self.parent_remote.send(["close", None])
|
||||
# mp may be deleted so it may raise AttributeError
|
||||
self.parent_remote.recv()
|
||||
self.process.join()
|
||||
except (BrokenPipeError, EOFError, AttributeError):
|
||||
pass
|
||||
# ensure the subproc is terminated
|
||||
self.process.terminate()
|
||||
|
||||
|
||||
class BaseVectorEnv(gym.Env):
|
||||
"""Base class for vectorized environments wrapper.
|
||||
Usage:
|
||||
::
|
||||
env_num = 8
|
||||
envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
|
||||
assert len(envs) == env_num
|
||||
It accepts a list of environment generators. In other words, an environment
|
||||
generator ``efn`` of a specific task means that ``efn()`` returns the
|
||||
environment of the given task, for example, ``gym.make(task)``.
|
||||
All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`.
|
||||
Here are some other usages:
|
||||
::
|
||||
envs.seed(2) # which is equal to the next line
|
||||
envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env
|
||||
obs = envs.reset() # reset all environments
|
||||
obs = envs.reset([0, 5, 7]) # reset 3 specific environments
|
||||
obs, rew, done, info = envs.step([1] * 8) # step synchronously
|
||||
envs.render() # render all environments
|
||||
envs.close() # close all environments
|
||||
.. warning::
|
||||
If you use your own environment, please make sure the ``seed`` method
|
||||
is set up properly, e.g.,
|
||||
::
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
|
||||
:param env_fns: a list of callable envs
|
||||
:param env:
|
||||
:param worker_fn: a callable worker
|
||||
:param worker: which contains the i
|
||||
:param int: wait_num
|
||||
:param env: step
|
||||
:param environments: to finish a step is time
|
||||
:param return: when
|
||||
:param simulation: in these environments
|
||||
:param is: disabled
|
||||
:param float: timeout
|
||||
:param vectorized: step it only deal with those environments spending time
|
||||
:param within: timeout
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
|
||||
sampler=None,
|
||||
testing: Optional[bool] = False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
self._env_fns = env_fns
|
||||
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
|
||||
# interact with the given envs (one worker <-> one env).
|
||||
self.workers = [worker_fn(fn) for fn in env_fns]
|
||||
self.worker_class = type(self.workers[0])
|
||||
assert issubclass(self.worker_class, EnvWorker)
|
||||
assert all([isinstance(w, self.worker_class) for w in self.workers])
|
||||
|
||||
self.env_num = len(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert 1 <= self.wait_num <= len(env_fns), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
self.timeout = timeout
|
||||
assert self.timeout is None or self.timeout > 0, f"timeout is {timeout}, it should be positive if provided!"
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None or testing
|
||||
self.waiting_conn: List[EnvWorker] = []
|
||||
# environments in self.ready_id is actually ready
|
||||
# but environments in self.waiting_id are just waiting when checked,
|
||||
# and they may be ready now, but this is not known until we check it
|
||||
# in the step() function
|
||||
self.waiting_id: List[int] = []
|
||||
# all environments are ready in the beginning
|
||||
self.ready_id = list(range(self.env_num))
|
||||
self.is_closed = False
|
||||
self.sampler = sampler
|
||||
self.sample_obs = None
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
""" """
|
||||
assert not self.is_closed, f"Methods of {self.__class__.__name__} cannot be called after " "close."
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
return self.env_num
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
"""Switch the attribute getter depending on the key.
|
||||
Any class who inherits ``gym.Env`` will inherit some attributes, like
|
||||
``action_space``. However, we would like the attribute lookup to go
|
||||
straight into the worker (in fact, this vector env's action_space is
|
||||
always None).
|
||||
"""
|
||||
if key in [
|
||||
"metadata",
|
||||
"reward_range",
|
||||
"spec",
|
||||
"action_space",
|
||||
"observation_space",
|
||||
]: # reserved keys in gym.Env
|
||||
return self.__getattr__(key)
|
||||
else:
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __getattr__(self, key: str) -> List[Any]:
|
||||
"""Fetch a list of env attributes.
|
||||
This function tries to retrieve an attribute from each individual
|
||||
wrapped environment, if it does not belong to the wrapping vector
|
||||
environment class.
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
|
||||
def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> Union[List[int], np.ndarray]:
|
||||
"""
|
||||
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
|
||||
"""
|
||||
if id is None:
|
||||
id = list(range(self.env_num))
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
return id
|
||||
|
||||
def _assert_id(self, id: List[int]) -> None:
|
||||
"""
|
||||
|
||||
:param id: List[int]:
|
||||
:param id: List[int]:
|
||||
:param id: List[int]:
|
||||
|
||||
"""
|
||||
for i in id:
|
||||
assert i not in self.waiting_id, f"Cannot interact with environment {i} which is stepping now."
|
||||
assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}."
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None) -> np.ndarray:
|
||||
"""Reset the state of some envs and return initial observations.
|
||||
If id is None, reset the state of all the environments and return
|
||||
initial observations, otherwise reset the specific environments with
|
||||
the given id, either an int or a list.
|
||||
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param id: Optional[Union[int:
|
||||
|
||||
"""
|
||||
start_time = time.time()
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
obs = []
|
||||
stop_id = []
|
||||
for i in id:
|
||||
sample = self.sampler.sample()
|
||||
if sample is None:
|
||||
stop_id.append(i)
|
||||
else:
|
||||
self.workers[i].reset(sample)
|
||||
for i in id:
|
||||
if i in stop_id:
|
||||
obs.append(self.sample_obs)
|
||||
else:
|
||||
this_obs = self.workers[i].get_reset_result()
|
||||
if self.sample_obs is None:
|
||||
self.sample_obs = this_obs
|
||||
for j in range(len(obs)):
|
||||
if obs[j] is None:
|
||||
obs[j] = self.sample_obs
|
||||
obs.append(this_obs)
|
||||
|
||||
if len(obs) > 0:
|
||||
obs = np.stack(obs)
|
||||
# if len(stop_id)> 0:
|
||||
# obs_zero =
|
||||
# print(time.time() - start_timed)
|
||||
|
||||
return obs, stop_id
|
||||
|
||||
def toggle_log(self, log):
|
||||
for worker in self.workers:
|
||||
worker.toggle_log(log)
|
||||
|
||||
def reset_sampler(self):
|
||||
""" """
|
||||
self.sampler.reset()
|
||||
|
||||
def step(self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None) -> List[np.ndarray]:
|
||||
"""Run one timestep of some environments' dynamics.
|
||||
If id is None, run one timestep of all the environments’ dynamics;
|
||||
otherwise run one timestep for some environments with given id, either
|
||||
an int or a list. When the end of episode is reached, you are
|
||||
responsible for calling reset(id) to reset this environment’s state.
|
||||
Accept a batch of action and return a tuple (batch_obs, batch_rew,
|
||||
batch_done, batch_info) in numpy format.
|
||||
|
||||
:param numpy: ndarray action: a batch of action provided by the agent.
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:param List: int]:
|
||||
:param np: ndarray]]: (Default value = None)
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:param List[int]:
|
||||
:param np.ndarray]]: (Default value = None)
|
||||
:param action: np.ndarray:
|
||||
:param id: Optional[Union[int:
|
||||
:rtype: A tuple including four items
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if not self.is_async:
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.workers[j].send_action(action[i])
|
||||
result = []
|
||||
for j in id:
|
||||
obs, rew, done, info = self.workers[j].get_result()
|
||||
info["env_id"] = j
|
||||
result.append((obs, rew, done, info))
|
||||
else:
|
||||
if action is not None:
|
||||
self._assert_id(id)
|
||||
assert len(action) == len(id)
|
||||
for i, (act, env_id) in enumerate(zip(action, id)):
|
||||
self.workers[env_id].send_action(act)
|
||||
self.waiting_conn.append(self.workers[env_id])
|
||||
self.waiting_id.append(env_id)
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
ready_conns: List[EnvWorker] = []
|
||||
while not ready_conns:
|
||||
ready_conns = self.worker_class.wait(self.waiting_conn, self.wait_num, self.timeout)
|
||||
result = []
|
||||
for conn in ready_conns:
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
self.waiting_conn.pop(waiting_index)
|
||||
env_id = self.waiting_id.pop(waiting_index)
|
||||
obs, rew, done, info = conn.get_result()
|
||||
info["env_id"] = env_id
|
||||
result.append((obs, rew, done, info))
|
||||
self.ready_id.append(env_id)
|
||||
return list(map(np.stack, zip(*result)))
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[Optional[List[int]]]:
|
||||
"""Set the seed for all environments.
|
||||
Accept ``None``, an int (which will extend ``i`` to
|
||||
``[i, i + 1, i + 2, ...]``) or a list.
|
||||
|
||||
:param seed: Optional[Union[int:
|
||||
:param List: int]]]: (Default value = None)
|
||||
:param seed: Optional[Union[int:
|
||||
:param List[int]]]: (Default value = None)
|
||||
:param seed: Optional[Union[int:
|
||||
:returns: The list of seeds used in this env's random number generators.
|
||||
The first value in the list should be the "main" seed, or the value
|
||||
which a reproducer pass to "seed".
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
seed_list: Union[List[None], List[int]]
|
||||
if seed is None:
|
||||
seed_list = [seed] * self.env_num
|
||||
elif isinstance(seed, int):
|
||||
seed_list = [seed + i for i in range(self.env_num)]
|
||||
else:
|
||||
seed_list = seed
|
||||
return [w.seed(s) for w, s in zip(self.workers, seed_list)]
|
||||
|
||||
def render(self, **kwargs: Any) -> List[Any]:
|
||||
"""Render all of the environments.
|
||||
|
||||
:param **kwargs: Any:
|
||||
:param **kwargs: Any:
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
if self.is_async and len(self.waiting_id) > 0:
|
||||
raise RuntimeError(f"Environments {self.waiting_id} are still stepping, cannot " "render them now.")
|
||||
return [w.render(**kwargs) for w in self.workers]
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close all of the environments.
|
||||
This function will be called only once (if not, it will be called
|
||||
during garbage collected). This way, ``close`` of all workers can be
|
||||
assured.
|
||||
|
||||
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
for w in self.workers:
|
||||
w.close()
|
||||
self.is_closed = True
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Redirect to self.close()."""
|
||||
if not self.is_closed:
|
||||
self.close()
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
|
||||
"""Vectorized environment wrapper based on subprocess.
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
sampler=None,
|
||||
testing=False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
"""
|
||||
|
||||
:param fn: Callable[[]:
|
||||
:param gym: Env]:
|
||||
:param fn: Callable[[]:
|
||||
:param gym.Env]:
|
||||
:param fn: Callable[[]:
|
||||
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=False)
|
||||
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class ShmemVectorEnv(BaseVectorEnv):
|
||||
"""Optimized SubprocVectorEnv with shared buffers to exchange observations.
|
||||
ShmemVectorEnv has exactly the same API as SubprocVectorEnv.
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
|
||||
detailed explanation.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
sampler=None,
|
||||
testing=False,
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
"""
|
||||
|
||||
:param fn: Callable[[]:
|
||||
:param gym: Env]:
|
||||
:param fn: Callable[[]:
|
||||
:param gym.Env]:
|
||||
:param fn: Callable[[]:
|
||||
|
||||
"""
|
||||
return SubprocEnvWorker(fn, share_memory=True)
|
||||
|
||||
super().__init__(env_fns, worker_fn, sampler, testing, wait_num=wait_num, timeout=timeout)
|
||||
@@ -28,17 +28,11 @@
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"################################# NOTE #################################\n",
|
||||
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
|
||||
"# in this cell, users should RESTART the runtime in order to run the #\n",
|
||||
"# following cells successfully. #\n",
|
||||
"########################################################################\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
@@ -244,7 +238,9 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
@@ -363,7 +359,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
"version": "3.7.9"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
@@ -381,4 +377,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,82 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
@@ -30,7 +90,7 @@ if __name__ == "__main__":
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"benchmark": benchmark,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
@@ -39,9 +99,9 @@ if __name__ == "__main__":
|
||||
},
|
||||
}
|
||||
|
||||
# model initialization
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# NOTE: This line is optional
|
||||
# It demonstrates that the dataset can be used standalone.
|
||||
@@ -50,16 +110,14 @@ if __name__ == "__main__":
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.6.3.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
__version__ = "0.6.1.99"
|
||||
|
||||
|
||||
import os
|
||||
@@ -11,13 +10,12 @@ import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .log import get_module_logger
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
@@ -50,6 +48,7 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
from .log import get_module_logger
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
@@ -148,78 +147,7 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
"""
|
||||
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
"""
|
||||
If users are building a project follow the following pattern.
|
||||
- Qlib is a sub folder in project path
|
||||
- There is a file named `config.yaml` in qlib.
|
||||
|
||||
For example:
|
||||
If your project file system stucuture follows such a pattern
|
||||
|
||||
<project_path>/
|
||||
- config.yaml
|
||||
- ...some folders...
|
||||
- qlib/
|
||||
|
||||
This folder will return <project_path>
|
||||
|
||||
NOTE: link is not supported here.
|
||||
|
||||
|
||||
This method is often used when
|
||||
- user want to use a relative config path instead of hard-coding qlib config path in code
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError:
|
||||
If project path is not found
|
||||
"""
|
||||
if cur_path is None:
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
while True:
|
||||
if (cur_path / config_name).exists():
|
||||
return cur_path
|
||||
if cur_path == cur_path.parent:
|
||||
raise FileNotFoundError("We can't find the project path")
|
||||
cur_path = cur_path.parent
|
||||
|
||||
|
||||
def auto_init(**kwargs):
|
||||
"""
|
||||
This function will init qlib automatically with following priority
|
||||
- Find the project configuration and init qlib
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
"""
|
||||
|
||||
try:
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
conf_type = conf.get("conf_type", "origin")
|
||||
if conf_type == "origin":
|
||||
# The type of config is just like original qlib config
|
||||
init_from_yaml_conf(conf_pp, **kwargs)
|
||||
elif conf_type == "ref":
|
||||
# This config type will be more convenient in following scenario
|
||||
# - There is a shared configure file and you don't want to edit it inplace.
|
||||
# - The shared configure may be updated later and you don't want to copy it.
|
||||
# - You have some customized config.
|
||||
qlib_conf_path = conf["qlib_cfg"]
|
||||
qlib_conf_update = conf.get("qlib_cfg_update")
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
|
||||
logger = get_module_logger("Initialization")
|
||||
logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -33,9 +33,6 @@ class Config:
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._config")
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.__dict__["_config"].get(key, default)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_config"][key] = value
|
||||
|
||||
@@ -108,7 +105,7 @@ _default_config = {
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": logging.INFO,
|
||||
"logging_level": "INFO",
|
||||
# Global configuration of qlib log
|
||||
# logging_level can control the logging level more finely
|
||||
"logging_config": {
|
||||
@@ -127,14 +124,14 @@ _default_config = {
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": logging.DEBUG,
|
||||
"level": "DEBUG",
|
||||
"formatter": "logger_format",
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
},
|
||||
# Default config for experiment manager
|
||||
# Defatult config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
@@ -143,11 +140,6 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
},
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
@@ -193,7 +185,7 @@ MODE_CONF = {
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
# serversS(such as PAI) [auto_mount:True]
|
||||
"timeout": 100,
|
||||
"logging_level": logging.INFO,
|
||||
"logging_level": "INFO",
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
"custom_ops": [],
|
||||
@@ -318,22 +310,8 @@ class QlibConfig(Config):
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
|
||||
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
|
||||
self.reset_qlib_version()
|
||||
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
|
||||
import qlib
|
||||
|
||||
reset_version = self.get("qlib_reset_version", None)
|
||||
if reset_version is not None:
|
||||
qlib.__version__ = reset_version
|
||||
else:
|
||||
qlib.__version__ = getattr(qlib, "__version__bak")
|
||||
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
|
||||
# Using __version__bak instead of __version__
|
||||
|
||||
@property
|
||||
def registered(self):
|
||||
return self._registered
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user