mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 14:01:28 +08:00
Compare commits
82 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee3d4092ae | ||
|
|
ae83f9056f | ||
|
|
c276de4040 | ||
|
|
84103c7d43 | ||
|
|
6d6c586dc2 | ||
|
|
54ef18ec4e | ||
|
|
0dfbf8c413 | ||
|
|
f9c35284e1 | ||
|
|
3974bfe746 | ||
|
|
45ebb1d0e0 | ||
|
|
103f8579c1 | ||
|
|
654033733d | ||
|
|
d224ea447e | ||
|
|
9265b66e09 | ||
|
|
d4b56d97b5 | ||
|
|
0b3b95f22f | ||
|
|
0596174b94 | ||
|
|
779b1786bd | ||
|
|
007082a112 | ||
|
|
4e380b611e | ||
|
|
1e410c99be | ||
|
|
b223c4304d | ||
|
|
f2771f1beb | ||
|
|
01bdf6c1b1 | ||
|
|
9639a8cac9 | ||
|
|
cae4c9c924 | ||
|
|
a2be6e28e9 | ||
|
|
fdbc666678 | ||
|
|
7800dd4ec9 | ||
|
|
3fa48d7017 | ||
|
|
b18132dce1 | ||
|
|
16957176a9 | ||
|
|
f0b9a807ea | ||
|
|
5ee2d9496b | ||
|
|
4f2d6b0d84 | ||
|
|
3943b7001f | ||
|
|
2593185721 | ||
|
|
7a884fa9f2 | ||
|
|
d929d4bb21 | ||
|
|
e54b019ee2 | ||
|
|
426b98a3bc | ||
|
|
82f8ff9066 | ||
|
|
7b15682c63 | ||
|
|
df36839a7f | ||
|
|
4cecaba618 | ||
|
|
63b823f343 | ||
|
|
e41c0ac90a | ||
|
|
31e9d529de | ||
|
|
5fa56703ae | ||
|
|
c6bb11fe56 | ||
|
|
3d7ebd1fe0 | ||
|
|
7313b4dad0 | ||
|
|
b70caff522 | ||
|
|
96b422a906 | ||
|
|
64130d9407 | ||
|
|
a58bc03a8e | ||
|
|
f537222ce3 | ||
|
|
c427c64845 | ||
|
|
22ff8fdc44 | ||
|
|
4efb0a75c1 | ||
|
|
052aad7982 | ||
|
|
12f05c7182 | ||
|
|
ac08468330 | ||
|
|
df9745f134 | ||
|
|
2e49a5f7c0 | ||
|
|
3ab5721448 | ||
|
|
6a94b45503 | ||
|
|
7c31012b50 | ||
|
|
334b92ace7 | ||
|
|
9a175d7507 | ||
|
|
17ea44e0cf | ||
|
|
c0ce712be9 | ||
|
|
8e81a017c1 | ||
|
|
706727988c | ||
|
|
e99224e5c2 | ||
|
|
8c8d1336de | ||
|
|
d01de411a8 | ||
|
|
28fe4d4bb4 | ||
|
|
873129aa9b | ||
|
|
3a152f9b8b | ||
|
|
2b75b41a08 | ||
|
|
00d17f0a52 |
8
.github/workflows/python-publish.yml
vendored
8
.github/workflows/python-publish.yml
vendored
@@ -12,8 +12,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
os: [windows-latest, macos-latest, macos-11]
|
||||
# Python 3.6 is not supported because it lacks annotations support: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -44,7 +45,8 @@ jobs:
|
||||
- name: Build wheel on Linux
|
||||
uses: RalfG/python-wheels-manylinux-build@v0.3.1-manylinux2010_x86_64
|
||||
with:
|
||||
python-versions: 'cp36-cp36m cp37-cp37m cp38-cp38'
|
||||
# Python 3.6 is not supported because it lacks annotations support: https://stackoverflow.com/a/52890129
|
||||
python-versions: 'cp37-cp37m cp38-cp38'
|
||||
build-requirements: 'numpy cython'
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
|
||||
12
.github/workflows/test.yml
vendored
12
.github/workflows/test.yml
vendored
@@ -13,7 +13,8 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-18.04, ubuntu-20.04]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
# Python 3.6 is not supported because it lacks annotations support: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
@@ -49,15 +50,6 @@ jobs:
|
||||
pip install --upgrade cython jupyter jupyter_contrib_nbextensions numpy scipy scikit-learn # installing without this line will cause errors on GitHub Actions, while installing locally won't
|
||||
pip install -e .
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
if [ "$RUNNER_OS" == "Windows" ]; then
|
||||
$CONDA\\python.exe scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
else
|
||||
$CONDA/bin/python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
|
||||
6
.github/workflows/test_macos.yml
vendored
6
.github/workflows/test_macos.yml
vendored
@@ -10,10 +10,12 @@ on:
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: macos-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
os: [macos-11, macos-latest]
|
||||
# Python 3.6 is not supported because it lacks annotations support: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
17
CHANGES.rst
17
CHANGES.rst
@@ -159,6 +159,21 @@ Version 0.5.0
|
||||
- Add baselines
|
||||
- public data crawler
|
||||
|
||||
Version greater than Version 0.5.0
|
||||
|
||||
Version 0.8.0
|
||||
--------------------
|
||||
- The backtest is greatly refactored.
|
||||
- Nested decision execution framework is supported
|
||||
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
|
||||
- The trading limitation is more accurate;
|
||||
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`_, longing and shorting actions share the same action.
|
||||
- In `current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`_, the trading limitation is different between longing and shorting actions.
|
||||
- The constant is different when calculating annualized metrics.
|
||||
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`_
|
||||
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`_ of data is released. Due to the instability of the Yahoo data source, the data may be different after it is downloaded again.
|
||||
- Users could check out the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`_ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`_
|
||||
|
||||
|
||||
Other Versions
|
||||
----------------------------------
|
||||
Please refer to `Github release Notes <https://github.com/microsoft/qlib/releases>`_
|
||||
|
||||
45
README.md
45
README.md
@@ -11,6 +11,9 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| ADD model | [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 |
|
||||
| ADARNN model | [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 |
|
||||
| TCN model | [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 |
|
||||
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
@@ -79,7 +82,7 @@ At the module level, Qlib is a platform that consists of the above components. T
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Decision Generator` will generate the target trading decisions(i.e. portfolio, orders) to be executed by `Execution Env` (i.e. the trading market). There may be multiple levels of `Trading Agent` and `Execution Env` (e.g. an _order executor trading agent and intraday order execution environment_ could behave like an interday trading environment and nested in _daily portfolio management trading agent and interday trading environment_ ) |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
@@ -100,7 +103,6 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
@@ -279,22 +281,25 @@ The automatic workflow may not suit the research workflow of all Quant researche
|
||||
# [Quant Model (Paper) Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural omputation 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](examples/benchmarks/XGBoost/)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](examples/benchmarks/LightGBM/)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](examples/benchmarks/CatBoost/)
|
||||
- [MLP based on pytorch](examples/benchmarks/MLP/)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural computation 1997)](examples/benchmarks/LSTM/)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](examples/benchmarks/GRU/)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](examples/benchmarks/ALSTM)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](examples/benchmarks/GATs/)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](examples/benchmarks/SFM/)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](examples/benchmarks/TabNet/)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](examples/benchmarks/DoubleEnsemble/)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](examples/benchmarks/TCTS/)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](examples/benchmarks/Transformer/)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](examples/benchmarks/Localformer/)
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](examples/benchmarks/TRA/)
|
||||
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
|
||||
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
|
||||
- [ADD based on pytorch (Hongshun Tang, et al. 2020)](examples/benchmarks/ADD/)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -307,7 +312,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
|
||||
- Users can use the tool `qrun` mentioned above to run a model's workflow based on a config file.
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py run --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)
|
||||
|
||||
## Run multiple models
|
||||
@@ -317,7 +322,7 @@ The script will create a unique virtual environment for each model, and delete t
|
||||
|
||||
Here is an example of running all the models for 10 iterations:
|
||||
```python
|
||||
python run_all_model.py 10
|
||||
python run_all_model.py run 10
|
||||
```
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.7.2.99
|
||||
0.8.0
|
||||
|
||||
4
docs/_static/img/Task-Gen-Recorder-Collector.svg
vendored
Normal file
4
docs/_static/img/Task-Gen-Recorder-Collector.svg
vendored
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 198 KiB |
@@ -11,7 +11,10 @@ Introduction
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models. The processes of task generation, model training, and data combination and collection are shown in the following figure.
|
||||
|
||||
.. image:: ../_static/img/Task-Gen-Recorder-Collector.svg
|
||||
:align: center
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
@@ -74,6 +77,8 @@ If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
Before collecting model training results, you need to use the ``qlib.init`` to specify the path of mlruns.
|
||||
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
|
||||
@@ -82,8 +87,10 @@ To collect the results of ``task`` after training, ``Qlib`` provides `Collector
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object.
|
||||
You can set the ensembles you want in the ``Collector``'s process_list.
|
||||
Common ensembles include ``AverageEnsemble`` and ``RollingEnsemble``. ``AverageEnsemble`` is used to ensemble the results of different models in the same time period. ``RollingEnsemble`` is used to ensemble the results of rolling tasks from different time periods.
|
||||
|
||||
So the hierarchy is as follows: the second step of ``Collector`` corresponds to ``Group``, and the second step of ``Group`` corresponds to ``Ensemble``.
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
.. _backtest:
|
||||
|
||||
============================================
|
||||
Intraday Trading: Model&Strategy Testing
|
||||
============================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Intraday Trading`` is designed to test models and strategies, which help users to check the performance of a custom model/strategy.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
``Intraday Trading`` uses ``Order Executor`` to trade and execute orders output by ``Portfolio Strategy``. ``Order Executor`` is a component in `Qlib Framework <../introduction/introduction.html#framework>`_, which can execute orders. ``VWAP Executor`` and ``Close Executor`` is supported by ``Qlib`` now. In the future, ``Qlib`` will support ``HighFreq Executor`` also.
|
||||
|
||||
|
||||
|
||||
Example
|
||||
===========================
|
||||
|
||||
Users need to generate a `prediction score`(a pandas DataFrame) with MultiIndex<instrument, datetime> and a `score` column. And users need to assign a strategy used in backtest, if strategy is not assigned,
|
||||
a `TopkDropoutStrategy` strategy with `(topk=50, n_drop=5, risk_degree=0.95, limit_threshold=0.0095)` will be used.
|
||||
If ``Strategy`` module is not users' interested part, `TopkDropoutStrategy` is enough.
|
||||
|
||||
The simple example of the default strategy is as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.contrib.evaluate import backtest
|
||||
# pred_score is the prediction score
|
||||
report, positions = backtest(pred_score, topk=50, n_drop=0.5, limit_threshold=0.0095)
|
||||
|
||||
To know more about backtesting with a specific ``Strategy``, please refer to `Portfolio Strategy <strategy.html>`_.
|
||||
|
||||
To know more about the prediction score `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
Prediction Score
|
||||
-----------------
|
||||
|
||||
The `prediction score` is a pandas DataFrame. Its index is <datetime(pd.Timestamp), instrument(str)> and it must
|
||||
contains a `score` column.
|
||||
|
||||
A prediction sample is shown as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
datetime instrument score
|
||||
2019-01-04 SH600000 -0.505488
|
||||
2019-01-04 SZ002531 -0.320391
|
||||
2019-01-04 SZ000999 0.583808
|
||||
2019-01-04 SZ300569 0.819628
|
||||
2019-01-04 SZ001696 -0.137140
|
||||
... ...
|
||||
2019-04-30 SZ000996 -1.027618
|
||||
2019-04-30 SH603127 0.225677
|
||||
2019-04-30 SH603126 0.462443
|
||||
2019-04-30 SH603133 -0.302460
|
||||
2019-04-30 SZ300760 -0.126383
|
||||
|
||||
``Forecast Model`` module can make predictions, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
Backtest Result
|
||||
------------------
|
||||
|
||||
The backtest results are in the following form:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
risk
|
||||
excess_return_without_cost mean 0.000605
|
||||
std 0.005481
|
||||
annualized_return 0.152373
|
||||
information_ratio 1.751319
|
||||
max_drawdown -0.059055
|
||||
excess_return_with_cost mean 0.000410
|
||||
std 0.005478
|
||||
annualized_return 0.103265
|
||||
information_ratio 1.187411
|
||||
max_drawdown -0.075024
|
||||
|
||||
|
||||
|
||||
- `excess_return_without_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) without cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` without cost. please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) without cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
|
||||
|
||||
- `excess_return_with_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) series with cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) series with cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) with cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` with cost. please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) with cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
|
||||
|
||||
|
||||
|
||||
Reference
|
||||
==============
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading <../reference/api.html#module-qlib.contrib.evaluate>`_.
|
||||
@@ -1,120 +1,31 @@
|
||||
.. _highfreq:
|
||||
|
||||
============================================
|
||||
Design of hierarchical order execution framework
|
||||
Design of Nested Decision Execution Framework for High-Frequency Trading
|
||||
============================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
In order to support reinforcement learning algorithms for high-frequency trading, a corresponding framework is required. None of the publicly available high-frequency trading frameworks now consider multi-layer trading mechanisms, and the currently designed algorithms cannot directly use existing frameworks.
|
||||
In addition to supporting the basic intraday multi-layer trading, the linkage with the day-ahead strategy is also a factor that affects the performance evaluation of the strategy. Different day strategies generate different order distributions and different patterns on different stocks. To verify that high-frequency trading strategies perform well on real trading orders, it is necessary to support day-frequency and high-frequency multi-level linkage trading. In addition to more accurate backtesting of high-frequency trading algorithms, if the distribution of day-frequency orders is considered when training a high-frequency trading model, the algorithm can also be optimized more for product-specific day-frequency orders.
|
||||
Therefore, innovation in the high-frequency trading framework is necessary to solve the various problems mentioned above, for which we designed a hierarchical order execution framework that can link daily-frequency and intra-day trading at different granularities.
|
||||
Daily trading (e.g. portfolio management) and intraday trading (e.g. orders execution) are two hot topics in Quant investment and usually studied separately.
|
||||
|
||||
To get the joint trading performance of daily and intraday trading, they must interact with each other and run backtest jointly.
|
||||
In order to support the joint backtest of strategies in multiple levels, a corresponding framework is required. None of the publicly available high-frequency trading frameworks considers multi-level joint trading, which makes the aforementioned backtesting inaccurate.
|
||||
|
||||
Besides backtesting, the optimization of strategies from different levels is not standalone and can be affected by each other.
|
||||
For example, the best portfolio management strategy may change with the performance of order executions (e.g. a portfolio with higher turnover may become a better choice when we improve the order execution strategies).
|
||||
To achieve overall good performance, it is necessary to consider the interaction of strategies at different levels.
|
||||
|
||||
Therefore, building a new framework for trading in multiple levels becomes necessary to solve the various problems mentioned above, for which we designed a nested decision execution framework that considers the interaction of strategies.
|
||||
|
||||
.. image:: ../_static/img/framework.svg
|
||||
|
||||
The design of the framework is shown in the figure above. At each layer consists of Trading Agent and Execution Env. The Trading Agent has its own data processing module (Information Extractor), forecasting module (Forecast Model) and decision generator (Decision Generator). The trading algorithm generates the corresponding decisions by the Decision Generator based on the forecast signals output by the Forecast Module, and the decisions generated by the trading algorithm are passed to the Execution Env, which returns the execution results. Here the frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intra-day trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The hierarchical order execution framework is user-defined in terms of hierarchy division and decision frequency, making it easy for users to explore the effects of combining different levels of trading algorithms and breaking down the barriers between different levels of trading algorithm optimization.
|
||||
In addition to the innovation in the framework, the hierarchical order execution framework also takes into account various details of the real backtesting environment, minimizing the differences with the final real environment as much as possible. At the same time, the framework is designed to unify the interface between online and offline (e.g. data pre-processing level supports using the same set of code to process both offline and online data) to reduce the cost of strategy go-live as much as possible.
|
||||
|
||||
Prepare Data
|
||||
===================
|
||||
.. _data:: ../../examples/highfreq/README.md
|
||||
The design of the framework is shown in the yellow part in the middle of the figure above. Each level consists of a ``Trading Agent`` and an ``Execution Env``. The ``Trading Agent`` has its own data processing module (``Information Extractor``), forecasting module (``Forecast Model``) and decision generator (``Decision Generator``). The trading algorithm generates the decisions by the ``Decision Generator`` based on the forecast signals output by the ``Forecast Model``, and the decisions generated by the trading algorithm are passed to the ``Execution Env``, which returns the execution results.
|
||||
|
||||
The frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of trading algorithm.
|
||||
|
||||
Example
|
||||
===========================
|
||||
|
||||
Here is an example of highfreq execution.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import qlib
|
||||
# init qlib
|
||||
provider_uri_day = "~/.qlib/qlib_data/cn_data"
|
||||
provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
qlib.init(provider_uri=provider_uri_map, expression_cache=None, dataset_cache=None)
|
||||
|
||||
# data freq and backtest time
|
||||
freq = "1min"
|
||||
inst_list = D.list_instruments(D.instruments("all"), as_list=True)
|
||||
start_time = "2020-01-01"
|
||||
end_time = "2020-01-31"
|
||||
|
||||
When initializing qlib, if the default data is used, then both daily and minute frequency data need to be passed in.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# random order strategy config
|
||||
strategy_config = {
|
||||
"class": "RandomOrderStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"trade_range": TradeRangeByTime("9:30", "15:00"),
|
||||
"sample_ratio": 1.0,
|
||||
"volume_ratio": 0.01,
|
||||
"market": market,
|
||||
},
|
||||
}
|
||||
|
||||
.. code-block:: python
|
||||
# backtest config
|
||||
backtest_config = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"account": 100000000,
|
||||
"benchmark": None,
|
||||
"exchange_kwargs": {
|
||||
"freq": freq,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"codes": market,
|
||||
},
|
||||
"pos_type": "InfPosition", # Position with infinite position
|
||||
}
|
||||
|
||||
For more details about the backtest configuration, please refer to "../../qlib/backtest".
|
||||
|
||||
.. code-block:: python
|
||||
# executor config
|
||||
executor_config = {
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": freq,
|
||||
"generate_portfolio_metrics": True,
|
||||
"verbose": False,
|
||||
# "verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"track_data": True,
|
||||
"generate_portfolio_metrics": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
``NestedExecutor`` indicates that the executor is not the innermost layer, so its initialization parameters should contain ``inner_executor`` and ``inner_strategy``. ``SimulatorExecutor`` indicates that the current executor is the innermost layer. The innermost strategy used here is the TWAP strategy; the framework currently also supports the VWAP strategy.
|
||||
|
||||
.. code-block:: python
|
||||
# backtest
|
||||
portfolio_metrics_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
|
||||
The metrics of backtest are included in the portfolio_metrics_dict and indicator_dict.
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
@@ -12,7 +12,9 @@ Introduction
|
||||
|
||||
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Portfolio Strategy`` can be used as an independent module also.
|
||||
|
||||
``Qlib`` provides several implemented portfolio strategies. Also, ``Qlib`` supports custom strategy, users can customize strategies according to their own needs.
|
||||
``Qlib`` provides several implemented portfolio strategies. Also, ``Qlib`` supports custom strategy, users can customize strategies according to their own requirements.
|
||||
|
||||
After users specify the models(forecasting signals) and strategies, running backtest will help users to check the performance of a custom model(forecasting signals)/strategy.
|
||||
|
||||
Base Class & Interface
|
||||
======================
|
||||
@@ -82,38 +84,203 @@ TopkDropoutStrategy
|
||||
|
||||
Usage & Example
|
||||
====================
|
||||
``Portfolio Strategy`` can be specified in the ``Intraday Trading(Backtest)``, the example is as follows.
|
||||
|
||||
First, user can create a model to get trading signals(the variable name is ``pred_score`` in following cases).
|
||||
|
||||
Prediction Score
|
||||
-----------------
|
||||
|
||||
The `prediction score` is a pandas DataFrame. Its index is <datetime(pd.Timestamp), instrument(str)> and it must
|
||||
contain a `score` column.
|
||||
|
||||
A prediction sample is shown as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import backtest
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
|
||||
}
|
||||
# use default strategy
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
datetime instrument score
|
||||
2019-01-04 SH600000 -0.505488
|
||||
2019-01-04 SZ002531 -0.320391
|
||||
2019-01-04 SZ000999 0.583808
|
||||
2019-01-04 SZ300569 0.819628
|
||||
2019-01-04 SZ001696 -0.137140
|
||||
... ...
|
||||
2019-04-30 SZ000996 -1.027618
|
||||
2019-04-30 SH603127 0.225677
|
||||
2019-04-30 SH603126 0.462443
|
||||
2019-04-30 SH603133 -0.302460
|
||||
2019-04-30 SZ300760 -0.126383
|
||||
|
||||
# pred_score is the `prediction score` output by Model
|
||||
report_normal, positions_normal = backtest(
|
||||
pred_score, strategy=strategy, **BACKTEST_CONFIG
|
||||
)
|
||||
``Forecast Model`` module can make predictions, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
Running backtest
|
||||
-----------------
|
||||
|
||||
- In most cases, users could backtest their portfolio management strategy with ``backtest_daily``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.contrib.evaluate import backtest_daily
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
# init qlib
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
# pred_score, pd.Series
|
||||
"signal": pred_score,
|
||||
}
|
||||
|
||||
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = backtest_daily(
|
||||
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
|
||||
)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
|
||||
|
||||
|
||||
- If users would like to control their strategies in a more detailed way (e.g. users have a more advanced version of executor), they could follow this example.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.backtest import backtest, executor
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
# init qlib
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
FREQ = "day"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
# pred_score, pd.Series
|
||||
"signal": pred_score,
|
||||
}
|
||||
|
||||
EXECUTOR_CONFIG = {
|
||||
"time_per_step": "day",
|
||||
"generate_portfolio_metrics": True,
|
||||
}
|
||||
|
||||
backtest_config = {
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"account": 100000000,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"exchange_kwargs": {
|
||||
"freq": FREQ,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
# strategy object
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
# executor object
|
||||
executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)
|
||||
# backtest
|
||||
portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)
|
||||
analysis_freq = "{0}{1}".format(*Freq.parse(FREQ))
|
||||
# backtest info
|
||||
report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
# print out results
|
||||
pprint(f"The following are analysis results of benchmark return({analysis_freq}).")
|
||||
pprint(risk_analysis(report_normal["bench"], freq=analysis_freq))
|
||||
pprint(f"The following are analysis results of the excess return without cost({analysis_freq}).")
|
||||
pprint(analysis["excess_return_without_cost"])
|
||||
pprint(f"The following are analysis results of the excess return with cost({analysis_freq}).")
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
|
||||
|
||||
Result
|
||||
------------------
|
||||
|
||||
The backtest results are in the following form:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
risk
|
||||
excess_return_without_cost mean 0.000605
|
||||
std 0.005481
|
||||
annualized_return 0.152373
|
||||
information_ratio 1.751319
|
||||
max_drawdown -0.059055
|
||||
excess_return_with_cost mean 0.000410
|
||||
std 0.005478
|
||||
annualized_return 0.103265
|
||||
information_ratio 1.187411
|
||||
max_drawdown -0.075024
|
||||
|
||||
|
||||
- `excess_return_without_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) without cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` without cost. Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) without cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
|
||||
|
||||
- `excess_return_with_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) series with cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) series with cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) with cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` with cost. Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) with cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
|
||||
|
||||
|
||||
Reference
|
||||
===================
|
||||
To know more about ``Portfolio Strategy``, please refer to `Strategy API <../reference/api.html#module-qlib.contrib.strategy.strategy>`_.
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
@@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``.
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
@@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
|
||||
@@ -38,8 +38,8 @@ Document Structure
|
||||
Workflow: Workflow Management <component/workflow.rst>
|
||||
Data Layer: Data Framework&Usage <component/data.rst>
|
||||
Forecast Model: Model Training & Prediction <component/model.rst>
|
||||
Strategy: Portfolio Management <component/strategy.rst>
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Portfolio Management and Backtest <component/strategy.rst>
|
||||
Nested Decision Execution: High-Frequency Trading <component/highfreq.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
@@ -34,9 +34,14 @@ Name Description
|
||||
|
||||
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
|
||||
`Information Extractor` extracts data for models. `Forecast Model` focuses
|
||||
on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
|
||||
modules. With these signals `Portfolio Generator` will generate the target
|
||||
portfolio and produce orders to be executed by `Order Executor`.
|
||||
on producing all kinds of forecast signals (e.g. *alpha*, risk) for other
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
trading decisions(i.e. portfolio, orders) to be executed by `Execution Env`
|
||||
(i.e. the trading market). There may be multiple levels of `Trading Agent`
|
||||
and `Execution Env` (e.g. an *order executor trading agent and intraday
|
||||
order execution environment* could behave like an interday trading
|
||||
environment and nested in *daily portfolio management trading agent and
|
||||
interday trading environment* )
|
||||
|
||||
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
|
||||
system. `Analyser` module will provide users detailed analysis reports of
|
||||
|
||||
@@ -48,6 +48,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
- ``qlib.config.REG_CN``: China stock market.
|
||||
|
||||
Different modes will result in different trading limitations and costs.
|
||||
The region is just `shortcuts for defining a batch of configurations <https://github.com/microsoft/qlib/blob/main/qlib/config.py#L239>`_. Users can set the key configurations manually if the existing region setting can't meet their requirements.
|
||||
- `redis_host`
|
||||
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
|
||||
The lock and cache mechanism relies on redis.
|
||||
|
||||
4
examples/benchmarks/ADARNN/README.md
Normal file
4
examples/benchmarks/ADARNN/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# AdaRNN
|
||||
* Code: [https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn](https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn)
|
||||
* Paper: [AdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/pdf/2108.04443.pdf).
|
||||
|
||||
4
examples/benchmarks/ADARNN/requirements.txt
Normal file
4
examples/benchmarks/ADARNN/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,88 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ADARNN
|
||||
module_path: qlib.contrib.model.pytorch_adarnn
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
3
examples/benchmarks/ADD/README.md
Normal file
3
examples/benchmarks/ADD/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# ADD
|
||||
* Paper: [ADD: Augmented Disentanglement Distillation Framework for Improving Stock Trend Forecasting](https://arxiv.org/abs/2012.06289).
|
||||
|
||||
4
examples/benchmarks/ADD/requirements.txt
Normal file
4
examples/benchmarks/ADD/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
94
examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml
Normal file
94
examples/benchmarks/ADD/workflow_config_add_Alpha360.yaml
Normal file
@@ -0,0 +1,94 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ADD
|
||||
module_path: qlib.contrib.model.pytorch_add
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.1
|
||||
dec_dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 5000
|
||||
metric: ic
|
||||
base_model: GRU
|
||||
gamma: 0.1
|
||||
gamma_clip: 0.2
|
||||
optimizer: adam
|
||||
mu: 0.2
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -86,4 +87,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -100,4 +101,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -94,4 +95,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
2
examples/benchmarks/GRU/README.md
Normal file
2
examples/benchmarks/GRU/README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
# Gated Recurrent Unit (GRU)
|
||||
* Paper: [Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation](https://aclanthology.org/D14-1179.pdf).
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
2
examples/benchmarks/LSTM/README.md
Normal file
2
examples/benchmarks/LSTM/README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
# Long Short-Term Memory (LSTM)
|
||||
* Paper: [Long Short-Term Memory](https://direct.mit.edu/neco/article-abstract/9/8/1735/6109/Long-Short-Term-Memory?redirectedFrom=fulltext).
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -14,7 +14,7 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
|
||||
@@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
@@ -80,4 +83,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -76,4 +77,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -31,18 +31,22 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
limit_threshold: 0.095
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
1
examples/benchmarks/Localformer/README.md
Normal file
1
examples/benchmarks/Localformer/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# Localformer
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
1
examples/benchmarks/MLP/README.md
Normal file
1
examples/benchmarks/MLP/README.md
Normal file
@@ -0,0 +1 @@
|
||||
# Multi-Layer Perceptron (MLP)
|
||||
@@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -98,4 +99,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -85,4 +86,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -1,51 +1,68 @@
|
||||
# Benchmarks Performance
|
||||
This page lists a batch of methods designed for alpha seeking. Each method tries to give scores/predictions for all stocks each day (e.g., forecasting the future excess return of stocks). The scores/predictions of the models will be used as the mined alpha. Investing in stocks with higher scores is expected to yield more profit.
|
||||
|
||||
The alpha is evaluated in two ways.
|
||||
1. The correlation between the alpha and future return.
|
||||
1. Constructing a portfolio based on the alpha and evaluating the final total return.
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` datasets with China's A-share stock & CSI300 data, respectively. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
|
||||
The numbers shown below demonstrate the performance of the entire `workflow` of each model. We will update the `workflow` as well as models in the near future for better results.
|
||||
|
||||
<!--
|
||||
> If you need to reproduce the results below, please use the **v1** dataset: `python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d --region cn --version v1`
|
||||
>
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ
|
||||
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ -->
|
||||
|
||||
> NOTE:
|
||||
> The backtest starting from version 0.8.0 is quite different from that of previous versions. Please check out the changelog for the differences.
|
||||
|
||||
## Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha360 | 0.0150±0.00 | 0.1049±0.00| 0.0284±0.00 | 0.1970±0.00 | -0.0659±0.00 | -0.7072±0.00| -0.2955±0.00 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha360 | 0.0397±0.00 | 0.2878±0.00| 0.0470±0.00 | 0.3703±0.00 | 0.0342±0.00 | 0.4092±0.00| -0.1057±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha360 | 0.0400±0.00 | 0.3031±0.00| 0.0461±0.00 | 0.3862±0.00 | 0.0528±0.00 | 0.6307±0.00| -0.1113±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha360 | 0.0399±0.00 | 0.3075±0.00| 0.0492±0.00 | 0.4019±0.00 | 0.0323±0.00 | 0.4370±0.00| -0.0917±0.00 |
|
||||
| MLP | Alpha360 | 0.0285±0.00 | 0.1981±0.02| 0.0402±0.00 | 0.2993±0.02 | 0.0073±0.02 | 0.0880±0.22| -0.1446±0.03 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha360 | 0.0490±0.01 | 0.3787±0.05| 0.0581±0.00 | 0.4664±0.04 | 0.0726±0.02 | 0.9817±0.34| -0.0902±0.03 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
| TCTS (Xueqing Wu, et al.)| Alpha360 | 0.0485±0.00 | 0.3689±0.04| 0.0586±0.00 | 0.4669±0.02 | 0.0816±0.02 | 1.1572±0.30| -0.0689±0.02 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha360 | 0.0141±0.00 | 0.0917±0.02| 0.0331±0.00 | 0.2357±0.03 | -0.0259±0.03 | -0.3323±0.43| -0.1763±0.07 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha360 | 0.0408±0.00 | 0.2988±0.03| 0.0538±0.00 | 0.4105±0.02 | 0.0275±0.03 | 0.3464±0.37| -0.1182±0.03 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha360 | 0.0491±0.01 | 0.3868±0.06 | 0.0589±0.00 | 0.4802±0.04 | 0.0898±0.02 | 1.2490±0.32 | -0.0778±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|---|---|---|---|---|---|---|---|---|
|
||||
| Linear | Alpha158 | 0.0393±0.00 | 0.2980±0.00| 0.0475±0.00 | 0.3546±0.00 | 0.0795±0.00 | 1.0712±0.00| -0.1449±0.00 |
|
||||
| CatBoost (Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0503±0.00 | 0.3586±0.00| 0.0483±0.00 | 0.3667±0.00 | 0.1080±0.00 | 1.1561±0.00| -0.0787±0.00 |
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
| Transformer (Ashish Vaswani, et al.)| Alpha158 | 0.0274±0.00 | 0.2166±0.04| 0.0409±0.00 | 0.3342±0.04 | 0.0204±0.03 | 0.2888±0.40| -0.1216±0.04 |
|
||||
| Localformer (Juyong Jiang, et al.)| Alpha158 | 0.0355±0.00 | 0.2747±0.04| 0.0466±0.00 | 0.3762±0.03 | 0.0506±0.02 | 0.7447±0.34| -0.0875±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 (with selected 20 features)| 0.0409±0.00 | 0.3253±0.04 | 0.0488±0.00 | 0.4045±0.02 | 0.0673±0.02 | 1.0389±0.39 | -0.0830±0.02 |
|
||||
| TRA (Hengxu Lin, et al.)| Alpha158 | 0.0442±0.00 | 0.3426±0.03 | 0.0555±0.00 | 0.4395±0.03 | 0.0833±0.03 | 1.2064±0.36 | -0.0849±0.02 |
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
|
||||
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
|
||||
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
|
||||
| LSTM(Sepp Hochreiter, et al.) | Alpha158(with selected 20 features) | 0.0318±0.00 | 0.2367±0.04 | 0.0435±0.00 | 0.3389±0.03 | 0.0381±0.03 | 0.5561±0.46 | -0.1207±0.04 |
|
||||
| Localformer(Juyong Jiang, et al.) | Alpha158 | 0.0356±0.00 | 0.2756±0.03 | 0.0468±0.00 | 0.3784±0.03 | 0.0438±0.02 | 0.6600±0.33 | -0.0952±0.02 |
|
||||
| SFM(Liheng Zhang, et al.) | Alpha158 | 0.0379±0.00 | 0.2959±0.04 | 0.0464±0.00 | 0.3825±0.04 | 0.0465±0.02 | 0.5672±0.29 | -0.1282±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158(with selected 20 features) | 0.0362±0.01 | 0.2789±0.06 | 0.0463±0.01 | 0.3661±0.05 | 0.0470±0.03 | 0.6992±0.47 | -0.1072±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158(with selected 20 features) | 0.0349±0.00 | 0.2511±0.01 | 0.0462±0.00 | 0.3564±0.01 | 0.0497±0.01 | 0.7338±0.19 | -0.0777±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha158(with selected 20 features) | 0.0404±0.00 | 0.3197±0.05 | 0.0490±0.00 | 0.4047±0.04 | 0.0649±0.02 | 1.0091±0.30 | -0.0860±0.02 |
|
||||
| Linear | Alpha158 | 0.0397±0.00 | 0.3000±0.00 | 0.0472±0.00 | 0.3531±0.00 | 0.0692±0.00 | 0.9209±0.00 | -0.1509±0.00 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha158 | 0.0440±0.00 | 0.3535±0.05 | 0.0540±0.00 | 0.4451±0.03 | 0.0718±0.02 | 1.0835±0.35 | -0.0760±0.02 |
|
||||
| CatBoost(Liudmila Prokhorenkova, et al.) | Alpha158 | 0.0481±0.00 | 0.3366±0.00 | 0.0454±0.00 | 0.3311±0.00 | 0.0765±0.00 | 0.8032±0.01 | -0.1092±0.00 |
|
||||
| XGBoost(Tianqi Chen, et al.) | Alpha158 | 0.0498±0.00 | 0.3779±0.00 | 0.0505±0.00 | 0.4131±0.00 | 0.0780±0.00 | 0.9070±0.00 | -0.1168±0.00 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158(with selected 20 features) | 0.0358±0.00 | 0.2160±0.03 | 0.0116±0.01 | 0.0720±0.03 | 0.0847±0.02 | 0.8131±0.19 | -0.1824±0.03 |
|
||||
| MLP | Alpha158 | 0.0376±0.00 | 0.2846±0.02 | 0.0429±0.00 | 0.3220±0.01 | 0.0895±0.02 | 1.1408±0.23 | -0.1103±0.02 |
|
||||
| LightGBM(Guolin Ke, et al.) | Alpha158 | 0.0448±0.00 | 0.3660±0.00 | 0.0469±0.00 | 0.3877±0.00 | 0.0901±0.00 | 1.0164±0.00 | -0.1038±0.00 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
|
||||
|
||||
|
||||
## Alpha360 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| Transformer(Ashish Vaswani, et al.) | Alpha360 | 0.0114±0.00 | 0.0716±0.03 | 0.0327±0.00 | 0.2248±0.02 | -0.0270±0.03 | -0.3378±0.37 | -0.1653±0.05 |
|
||||
| TabNet(Sercan O. Arik, et al.) | Alpha360 | 0.0099±0.00 | 0.0593±0.00 | 0.0290±0.00 | 0.1887±0.00 | -0.0369±0.00 | -0.3892±0.00 | -0.2145±0.00 |
|
||||
| MLP | Alpha360 | 0.0273±0.00 | 0.1870±0.02 | 0.0396±0.00 | 0.2910±0.02 | 0.0029±0.02 | 0.0274±0.23 | -0.1385±0.03 |
|
||||
| Localformer(Juyong Jiang, et al.) | Alpha360 | 0.0404±0.00 | 0.2932±0.04 | 0.0542±0.00 | 0.4110±0.03 | 0.0246±0.02 | 0.3211±0.21 | -0.1095±0.02 |
|
||||
| CatBoost(Liudmila Prokhorenkova, et al.)  | Alpha360 | 0.0378±0.00 | 0.2714±0.00 | 0.0467±0.00 | 0.3659±0.00 | 0.0292±0.00 | 0.3781±0.00 | -0.0862±0.00 |
|
||||
| XGBoost(Tianqi Chen, et al.) | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00 | 0.4527±0.02 | -0.1004±0.00 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0404±0.00 | 0.3023±0.00 | 0.0495±0.00 | 0.3898±0.00 | 0.0468±0.01 | 0.6302±0.20 | -0.0860±0.01 |
|
||||
| LightGBM(Guolin Ke, et al.) | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00 | 0.7632±0.00 | -0.0659±0.00 |
|
||||
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
|
||||
| LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 |
|
||||
| ADD | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02 | 0.8992±0.34 | -0.0855±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 |
|
||||
| AdaRNN(Yuntao Du, et al.) | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03 | 1.0200±0.40 | -0.0936±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
- The base model of TCTS is GRU.
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
4
examples/benchmarks/TCN/README.md
Normal file
4
examples/benchmarks/TCN/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# TCN
|
||||
* Code: [https://github.com/locuslab/TCN](https://github.com/locuslab/TCN)
|
||||
* Paper: [An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling](https://arxiv.org/abs/1803.01271).
|
||||
|
||||
4
examples/benchmarks/TCN/requirements.txt
Normal file
4
examples/benchmarks/TCN/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
100
examples/benchmarks/TCN/workflow_config_tcn_Alpha158.yaml
Executable file
100
examples/benchmarks/TCN/workflow_config_tcn_Alpha158.yaml
Executable file
@@ -0,0 +1,100 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCN
|
||||
module_path: qlib.contrib.model.pytorch_tcn_ts
|
||||
kwargs:
|
||||
d_feat: 20
|
||||
num_layers: 5
|
||||
n_chans: 32
|
||||
kernel_size: 7
|
||||
dropout: 0.5
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
90
examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml
Normal file
90
examples/benchmarks/TCN/workflow_config_tcn_Alpha360.yaml
Normal file
@@ -0,0 +1,90 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCN
|
||||
module_path: qlib.contrib.model.pytorch_tcn
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
num_layers: 5
|
||||
n_chans: 128
|
||||
kernel_size: 3
|
||||
dropout: 0.5
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,52 +1,38 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. This work introduces a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in the current minibatch) and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
At step <img src="https://latex.codecogs.com/png.latex?s" title="s" />, with training data <img src="https://latex.codecogs.com/png.latex?x_s,y_s" title="x_s,y_s" />, the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> chooses a suitable task <img src="https://latex.codecogs.com/png.latex?T_{i_s}" title="T_{i_s}" /> (green solid lines) to update the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> (blue solid lines). After <img src="https://latex.codecogs.com/png.latex?S" title="S" /> steps, we evaluate the model <img src="https://latex.codecogs.com/png.latex?f" title="f" /> on the validation set and update the scheduler <img src="https://latex.codecogs.com/png.latex?\varphi" title="\varphi" /> (green dashed lines).
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
Due to different data versions and different Qlib versions, the original data and data preprocessing methods used for the experimental settings in the paper differ from those in the current Qlib version. Therefore, we provide two versions of the code according to the two kinds of settings, 1) the [code](https://github.com/lwwang1995/tcts) that can be used to reproduce the experimental results and 2) the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) in the current Qlib baseline.
|
||||
|
||||
#### Setting1
|
||||
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
* The main task <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting the return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as follows,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t+k}}{price_i^{t+k-1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+k}}{price_i^{t+k-1}}-1" />
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
* Temporally correlated task sets <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_k&space;=&space;\{T_1,&space;T_2,&space;...&space;,&space;T_k\}" title="\mathcal{T}_k = \{T_1, T_2, ... , T_k\}" />, in this paper, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" />, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_5" title="\mathcal{T}_5" /> and <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_{10}" title="\mathcal{T}_{10}" /> are used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />, <img src="https://latex.codecogs.com/png.latex?T_2" title="T_2" />, and <img src="https://latex.codecogs.com/png.latex?T_3" title="T_3" />.
|
||||
|
||||
#### Setting2
|
||||
* Dataset: We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020. We split the data into training (01/01/2008-12/31/2014), validation (01/01/2015-12/31/2016), and test sets (01/01/2017-08/01/2020) based on the transaction time.
|
||||
|
||||
* The main task <img src="https://latex.codecogs.com/png.latex?T_k" title="T_k" /> refers to forecasting the return of stock <img src="https://latex.codecogs.com/png.latex?i" title="i" /> as follows,
|
||||
<div align=center>
|
||||
<img src="https://latex.codecogs.com/png.image?\dpi{110}&space;r_{i}^{t,k}&space;=&space;\frac{price_i^{t+1+k}}{price_i^{t+1}}-1" title="r_{i}^{t,k} = \frac{price_i^{t+1+k}}{price_i^{t+1}}-1" />
|
||||
</div>
|
||||
|
||||
* In the Qlib baseline, <img src="https://latex.codecogs.com/png.latex?\mathcal{T}_3" title="\mathcal{T}_3" /> is used in <img src="https://latex.codecogs.com/png.latex?T_1" title="T_1" />.
|
||||
|
||||
### Experimental Result
|
||||
You can find the experimental result of setting1 in the [paper](http://proceedings.mlr.press/v139/wu21e/wu21e.pdf) and the experimental result of setting2 in this [page](https://github.com/microsoft/qlib/tree/main/examples/benchmarks).
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 25 KiB |
@@ -22,16 +22,17 @@ data_handler_config: &data_handler_config
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -1) / $close - 1",
|
||||
"Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -2) - 1"]
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -1) - 1",
|
||||
"Ref($close, -4) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -53,9 +54,8 @@ task:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
dropout: 0.3
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
@@ -64,10 +64,10 @@ task:
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 3
|
||||
fore_lr: 5e-4
|
||||
weight_lr: 5e-4
|
||||
fore_lr: 2e-3
|
||||
weight_lr: 2e-3
|
||||
steps: 3
|
||||
target_label: 1
|
||||
target_label: 0
|
||||
lowest_valid_performance: 0.993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
@@ -92,8 +92,7 @@ task:
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
label_col: 1
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
@@ -195,7 +195,8 @@ class Alpha158Formatter(GenericDataFormatter):
|
||||
|
||||
for col in column_names:
|
||||
if col not in {"forecast_time", "identifier"}:
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[col])
|
||||
# Using [col] is for aligning with the format when fitting
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[[col]])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -311,5 +311,11 @@ class TFTModel(ModelFT):
|
||||
# self.model.save(path)
|
||||
|
||||
# save qlib model wrapper
|
||||
self.model = None
|
||||
drop_attrs = ["model", "tf_graph", "sess", "data_formatter"]
|
||||
orig_attr = {}
|
||||
for attr in drop_attrs:
|
||||
orig_attr[attr] = getattr(self, attr)
|
||||
setattr(self, attr, None)
|
||||
super(TFTModel, self).to_pickle(path)
|
||||
for attr in drop_attrs:
|
||||
setattr(self, attr, orig_attr[attr])
|
||||
|
||||
@@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -38,7 +38,7 @@ class TRAModel(Model):
|
||||
model_init_state=None,
|
||||
lamb=0.0,
|
||||
rho=0.99,
|
||||
seed=0,
|
||||
seed=None,
|
||||
logdir=None,
|
||||
eval_train=True,
|
||||
eval_test=False,
|
||||
|
||||
@@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
3
examples/benchmarks/TabNet/README.md
Normal file
3
examples/benchmarks/TabNet/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# TabNet
|
||||
* Code: [https://github.com/dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)
|
||||
* Paper: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/pdf/1908.07442.pdf).
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -50,6 +51,7 @@ task:
|
||||
kwargs:
|
||||
d_feat: 158
|
||||
pretrain: True
|
||||
seed: 993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
@@ -50,6 +51,7 @@ task:
|
||||
kwargs:
|
||||
d_feat: 360
|
||||
pretrain: True
|
||||
seed: 993
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
3
examples/benchmarks/Transformer/README.md
Normal file
3
examples/benchmarks/Transformer/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Transformer
|
||||
* Code: [https://github.com/tensorflow/tensor2tensor](https://github.com/tensorflow/tensor2tensor)
|
||||
* Paper: [Attention is All you Need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf).
|
||||
@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
|
||||
2
examples/data_demo/README.md
Normal file
2
examples/data_demo/README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
# Introduction
|
||||
The examples in this folder demonstrate some common usages of the data-related modules of Qlib.
|
||||
53
examples/data_demo/data_cache_demo.py
Normal file
53
examples/data_demo/data_cache_demo.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show that the data modules of Qlib are serializable: users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
import subprocess
|
||||
import yaml
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
from qlib import init
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
# For general purpose, we use relative path
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
if __name__ == "__main__":
    # Demo flow: time a normal `qrun`, dump the prepared data handler to disk,
    # then time a second `qrun` whose config points at the dumped handler.
    init()

    config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"

    # 1) show original time
    with TimeInspector.logt("The original time without handler cache:"):
        # An argument list (instead of a shell string) keeps paths that contain
        # spaces or shell metacharacters intact.
        subprocess.run(["qrun", str(config_path)])

    # 2) dump handler
    # Close the config file deterministically instead of leaking the handle.
    with config_path.open() as f:
        task_config = yaml.safe_load(f)
    hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
    pprint(hd_conf)
    hd: DataHandlerLP = init_instance_by_config(hd_conf)
    hd_path = DIRNAME / "handler.pkl"
    hd.to_pickle(hd_path, dump_all=True)

    # 3) create new task with handler cache
    # The handler section of the config is replaced by a `file://` reference
    # to the pickle written above.
    new_task_config = deepcopy(task_config)
    new_task_config["task"]["dataset"]["kwargs"]["handler"] = f"file://{hd_path}"
    # The new yaml is written next to this script, so the original config's
    # directory is recorded under `sys.path` in the new config.
    new_task_config["sys"] = {"path": [str(config_path.parent.resolve())]}
    new_task_path = DIRNAME / "new_task.yaml"
    print("The location of the new task", new_task_path)

    # save new task
    with new_task_path.open("w") as f:
        yaml.safe_dump(new_task_config, f, indent=4, sort_keys=False)

    # 4) train model with new task
    with TimeInspector.logt("The time for task with handler cache:"):
        subprocess.run(["qrun", str(new_task_path)])
|
||||
59
examples/data_demo/data_mem_resuse_demo.py
Normal file
59
examples/data_demo/data_mem_resuse_demo.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show that the processed data of Qlib's data modules can be kept in memory and reused across tasks to avoid duplicated data loading and preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
import subprocess
|
||||
|
||||
import yaml
|
||||
|
||||
from qlib import init
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
# For general purpose, we use relative path
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
if __name__ == "__main__":
    # Demo flow: time repeated training that rebuilds the data handler on every
    # run, then time the same training when one handler instance is reused.
    init()

    repeat = 2
    exp_name = "data_mem_reuse_demo"

    config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
    # Close the config file deterministically instead of leaking the handle.
    with config_path.open() as f:
        task_config = yaml.safe_load(f)

    # 1) without using processed data in memory
    with TimeInspector.logt("The original time without reusing processed data in memory:"):
        for _ in range(repeat):
            task_train(task_config["task"], experiment_name=exp_name)

    # 2) prepare processed data in memory.
    hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
    pprint(hd_conf)
    hd: DataHandlerLP = init_instance_by_config(hd_conf)

    # 3) with reusing processed data in memory
    # The handler config is replaced by the already-built handler instance.
    new_task = deepcopy(task_config["task"])
    new_task["dataset"]["kwargs"]["handler"] = hd
    print(new_task)

    with TimeInspector.logt("The time with reusing processed data in memory:"):
        # this will save the time to reload and process data from disk(in `DataHandlerLP`)
        # It still takes a lot of time in the backtest phase
        for _ in range(repeat):
            task_train(new_task, experiment_name=exp_name)

    # 4) User can change other parts exclude processed data in memory(handler)
    new_task = deepcopy(task_config["task"])
    # The task is rebuilt from the raw config, so the in-memory handler has to
    # be injected again; otherwise this step would not reuse anything.
    new_task["dataset"]["kwargs"]["handler"] = hd
    new_task["dataset"]["kwargs"]["segments"]["train"] = ("20100101", "20131231")
    with TimeInspector.logt("The time with reusing processed data in memory:"):
        task_train(new_task, experiment_name=exp_name)
|
||||
@@ -30,6 +30,7 @@ Run the example by running the following command:
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of the signal test for benchmark models. We will keep updating benchmark models in the future.
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
| LightGBM | Alpha158 | 0.0349±0.00 | 0.3805±0.00| 0.0435±0.00 | 0.4724±0.00 | 0.5111±0.00 | 0.5428±0.00 | 0.000074±0.00 | 0.2677±0.00 |
|
||||
|
||||
@@ -59,7 +59,7 @@ task:
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs:
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
@@ -1,9 +1,105 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The expected result of `backtest` in the current version is as follows
|
||||
|
||||
'The following are analysis results of benchmark return(1day).'
|
||||
risk
|
||||
mean 0.000651
|
||||
std 0.012472
|
||||
annualized_return 0.154967
|
||||
information_ratio 0.805422
|
||||
max_drawdown -0.160445
|
||||
'The following are analysis results of the excess return without cost(1day).'
|
||||
risk
|
||||
mean 0.001258
|
||||
std 0.007575
|
||||
annualized_return 0.299303
|
||||
information_ratio 2.561219
|
||||
max_drawdown -0.068386
|
||||
'The following are analysis results of the excess return with cost(1day).'
|
||||
risk
|
||||
mean 0.001110
|
||||
std 0.007575
|
||||
annualized_return 0.264280
|
||||
information_ratio 2.261392
|
||||
max_drawdown -0.071842
|
||||
[1706497:MainThread](2021-12-07 14:08:30,263) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_30minute.
|
||||
pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of benchmark return(30minute).'
|
||||
risk
|
||||
mean 0.000078
|
||||
std 0.003646
|
||||
annualized_return 0.148787
|
||||
information_ratio 0.935252
|
||||
max_drawdown -0.142830
|
||||
('The following are analysis results of the excess return without '
|
||||
'cost(30minute).')
|
||||
risk
|
||||
mean 0.000174
|
||||
std 0.003343
|
||||
annualized_return 0.331867
|
||||
information_ratio 2.275019
|
||||
max_drawdown -0.074752
|
||||
'The following are analysis results of the excess return with cost(30minute).'
|
||||
risk
|
||||
mean 0.000155
|
||||
std 0.003343
|
||||
annualized_return 0.294536
|
||||
information_ratio 2.018860
|
||||
max_drawdown -0.075579
|
||||
[1706497:MainThread](2021-12-07 14:08:30,277) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_5minute.p
|
||||
kl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of benchmark return(5minute).'
|
||||
risk
|
||||
mean 0.000015
|
||||
std 0.001460
|
||||
annualized_return 0.172170
|
||||
information_ratio 1.103439
|
||||
max_drawdown -0.144807
|
||||
'The following are analysis results of the excess return without cost(5minute).'
|
||||
risk
|
||||
mean 0.000028
|
||||
std 0.001412
|
||||
annualized_return 0.319771
|
||||
information_ratio 2.119563
|
||||
max_drawdown -0.077426
|
||||
'The following are analysis results of the excess return with cost(5minute).'
|
||||
risk
|
||||
mean 0.000025
|
||||
std 0.001412
|
||||
annualized_return 0.281536
|
||||
information_ratio 1.866091
|
||||
max_drawdown -0.078194
|
||||
[1706497:MainThread](2021-12-07 14:08:30,287) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day
|
||||
.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of indicators(1day).'
|
||||
value
|
||||
ffr 0.945821
|
||||
pa 0.000324
|
||||
pos 0.542882
|
||||
[1706497:MainThread](2021-12-07 14:08:30,293) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_30mi
|
||||
nute.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of indicators(30minute).'
|
||||
value
|
||||
ffr 0.982910
|
||||
pa 0.000037
|
||||
pos 0.500806
|
||||
[1706497:MainThread](2021-12-07 14:08:30,302) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_5min
|
||||
ute.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of indicators(5minute).'
|
||||
value
|
||||
ffr 0.991017
|
||||
pa 0.000000
|
||||
pos 0.000000
|
||||
[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done
|
||||
"""
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
import qlib
|
||||
import fire
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.data import D
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
@@ -14,7 +110,6 @@ from qlib.backtest import collect_data
|
||||
|
||||
|
||||
class NestedDecisionExecutionWorkflow:
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
data_handler_config = {
|
||||
@@ -151,10 +246,9 @@ class NestedDecisionExecutionWorkflow:
|
||||
self._train_model(model, dataset)
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
@@ -163,13 +257,10 @@ class NestedDecisionExecutionWorkflow:
|
||||
self.port_analysis_config["backtest"]["benchmark"] = self.benchmark
|
||||
|
||||
with R.start(experiment_name="backtest"):
|
||||
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(
|
||||
recorder,
|
||||
self.port_analysis_config,
|
||||
risk_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_method="value_weighted",
|
||||
)
|
||||
par.generate()
|
||||
@@ -189,10 +280,9 @@ class NestedDecisionExecutionWorkflow:
|
||||
backtest_config["benchmark"] = self.benchmark
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
@@ -201,6 +291,101 @@ class NestedDecisionExecutionWorkflow:
|
||||
for trade_decision in data_generator:
|
||||
print(trade_decision)
|
||||
|
||||
# the code below is for checking, users don't have to care about it
|
||||
# The tests can be categorized into 2 types
|
||||
# 1) comparing same backtest
|
||||
# - Basic test idea: the shared accumulated values are equal in multiple levels
|
||||
# - Aligning the profit calculation between multiple levels and single levels.
|
||||
# 2) comparing different backtest
|
||||
# - Basic test idea:
|
||||
# - the daily backtest will be similar to the multi-level one (the data quality makes this gap smaller)
|
||||
|
||||
def check_diff_freq(self):
    """
    Check that the accumulated report values agree between frequencies.

    For each of the "account", "total_turnover" and "total_cost" columns, the
    normal reports of the "30minute", "5minute" and "1day" frequencies are
    loaded from a recorder of the "backtest" experiment, resampled to the last
    value of each day, and the "30minute" series is asserted to equal the
    "1day" series.

    Raises
    ------
    AssertionError
        if the resampled "30minute" and "1day" values differ for any checked column.
    """
    self._init_qlib()
    exp = R.get_exp(experiment_name="backtest")
    rec = next(iter(exp.list_recorders().values()))  # assuming this will get the latest recorder
    # NOTE: the loop variable used to be overwritten with "total_cost" on every
    # iteration, so only that column was ever checked (three times).
    for check_key in "account", "total_turnover", "total_cost":
        acc_dict = {
            freq: rec.load_object(f"portfolio_analysis/report_normal_{freq}.pkl")[check_key]
            for freq in ["30minute", "5minute", "1day"]
        }
        acc_df = pd.DataFrame(acc_dict)
        # keep the last value of each day; days missing in any frequency are dropped
        acc_resam = acc_df.resample("1d").last().dropna()
        assert (
            acc_resam["30minute"] == acc_resam["1day"]
        ).all(), f"`{check_key}` differs between the 30minute and 1day reports"
|
||||
|
||||
def backtest_only_daily(self):
    """
    Run a single-layer, day-level backtest as a reference for the nested execution.

    The model and dataset are created from ``self.task`` and trained, then a
    ``TopkDropoutStrategy`` is backtested with a day-level ``SimulatorExecutor``
    and the portfolio analysis is generated under the "backtest" experiment.

    Due to the low quality of the daily-level and minute-level data, the two
    kinds of backtest are hardly comparable. So this run is used for detecting
    serious bugs which make the results greatly different.

    Sample output:

    .. code-block:: shell

        [1724971:MainThread](2021-12-07 16:24:31,156) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_1day.pkl'
        has been saved as the artifact of the Experiment 2
        'The following are analysis results of benchmark return(1day).'
        risk
        mean 0.000651
        std 0.012472
        annualized_return 0.154967
        information_ratio 0.805422
        max_drawdown -0.160445
        'The following are analysis results of the excess return without cost(1day).'
        risk
        mean 0.001375
        std 0.006103
        annualized_return 0.327204
        information_ratio 3.475016
        max_drawdown -0.024927
        'The following are analysis results of the excess return with cost(1day).'
        risk
        mean 0.001184
        std 0.006091
        annualized_return 0.281801
        information_ratio 2.998749
        max_drawdown -0.029568
        [1724971:MainThread](2021-12-07 16:24:31,170) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day.
        pkl' has been saved as the artifact of the Experiment 2
        'The following are analysis results of indicators(1day).'
        value
        ffr 1.0
        pa 0.0
        pos 0.0
        [1724971:MainThread](2021-12-07 16:24:31,188) INFO - qlib.timer - [log.py:113] - Time cost: 0.007s | waiting `async_log` Done

    """
    self._init_qlib()

    # build and train the signal source
    model = init_instance_by_config(self.task["model"])
    dataset = init_instance_by_config(self.task["dataset"])
    self._train_model(model, dataset)

    # work on a copy so that the shared port_analysis_config stays untouched
    daily_conf = deepcopy(self.port_analysis_config)
    daily_conf["strategy"] = {
        "class": "TopkDropoutStrategy",
        "module_path": "qlib.contrib.strategy.signal_strategy",
        "kwargs": {
            "signal": (model, dataset),
            "topk": 50,
            "n_drop": 5,
        },
    }
    # a single day-level executor replaces the executor of the copied config
    daily_conf["executor"] = {
        "class": "SimulatorExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "generate_portfolio_metrics": True,
            "verbose": True,
        },
    }
    daily_conf["backtest"]["benchmark"] = self.benchmark

    with R.start(experiment_name="backtest"):
        analysis_record = PortAnaRecord(R.get_recorder(), daily_conf)
        analysis_record.generate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(NestedDecisionExecutionWorkflow)
|
||||
|
||||
@@ -151,6 +151,9 @@ def get_all_results(folders) -> dict:
|
||||
if recorders[recorder_id].status == "FINISHED":
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
if "1day.excess_return_with_cost.annualized_return" not in metrics:
|
||||
print(f"{recorder_id} is skipped due to incomplete result")
|
||||
continue
|
||||
result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"])
|
||||
@@ -200,174 +203,183 @@ def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir):
|
||||
return temp_path
|
||||
|
||||
|
||||
# function to run all the models
|
||||
@only_allow_defined_args
|
||||
def run(
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
exp_folder_name: str = "run_all_model_records",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
|
||||
for multiple times, and this will be fixed in the future development.
|
||||
class ModelRunner:
|
||||
def _init_qlib(self, exp_folder_name):
|
||||
# init qlib
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
qlib.init(
|
||||
exp_manager={
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times the model should be running.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be a URL on the web or a local path
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
# function to run all the models
|
||||
@only_allow_defined_args
|
||||
def run(
|
||||
self,
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
exp_folder_name: str = "run_all_model_records",
|
||||
wait_before_rm_env: bool = False,
|
||||
wait_when_err: bool = False,
|
||||
):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed. Besides, this script doesn't support parallel running the same model
|
||||
for multiple times, and this will be fixed in the future development.
|
||||
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times the model should be running.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be a URL on the web or a local path
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
wait before remove environment.
|
||||
wait_when_err : bool
|
||||
wait when errors raised when executing commands
|
||||
|
||||
.. code-block:: bash
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py 3
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py 3 mlp
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py run 3
|
||||
|
||||
# Case 3 - run specific models multiple times with specific dataset
|
||||
python run_all_model.py 3 mlp Alpha158
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py run 3 mlp
|
||||
|
||||
# Case 4 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py 3 [mlp,tft,lstm] --exclude=True
|
||||
# Case 3 - run specific models multiple times with specific dataset
|
||||
python run_all_model.py run 3 mlp Alpha158
|
||||
|
||||
# Case 5 - run specific models for one time
|
||||
python run_all_model.py --models=[mlp,lightgbm]
|
||||
# Case 4 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py run 3 [mlp,tft,lstm] --exclude=True
|
||||
|
||||
# Case 6 - run other models except those are given as arguments for one time
|
||||
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
|
||||
# Case 5 - run specific models for one time
|
||||
python run_all_model.py run --models=[mlp,lightgbm]
|
||||
|
||||
"""
|
||||
# init qlib
|
||||
GetData().qlib_data(exists_skip=True)
|
||||
qlib.init(
|
||||
exp_manager={
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / exp_folder_name),
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
)
|
||||
# Case 6 - run other models except those are given as arguments for one time
|
||||
python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True
|
||||
|
||||
# get all folders
|
||||
folders = get_all_folders(models, exclude)
|
||||
# init error messages:
|
||||
errors = dict()
|
||||
# run all the model for iterations
|
||||
for fn in folders:
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
if yaml_path is None:
|
||||
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
|
||||
continue
|
||||
sys.stderr.write("\n")
|
||||
# create env by anaconda
|
||||
temp_dir, env_path, python_path, conda_activate = create_env()
|
||||
"""
|
||||
self._init_qlib(exp_folder_name)
|
||||
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
with open(req_path) as f:
|
||||
content = f.read()
|
||||
if "torch" in content:
|
||||
# automatically install pytorch according to nvidia's version
|
||||
execute(
|
||||
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
|
||||
) # for automatically installing torch according to the nvidia driver
|
||||
execute(
|
||||
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
else:
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
# get all folders
|
||||
folders = get_all_folders(models, exclude)
|
||||
# init error messages:
|
||||
errors = dict()
|
||||
# run all the model for iterations
|
||||
for fn in folders:
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
if yaml_path is None:
|
||||
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
|
||||
continue
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
_errs.update({i: errs})
|
||||
errors[fn] = _errs
|
||||
# create env by anaconda
|
||||
temp_dir, env_path, python_path, conda_activate = create_env()
|
||||
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
with open(req_path) as f:
|
||||
content = f.read()
|
||||
if "torch" in content:
|
||||
# automatically install pytorch according to nvidia's version
|
||||
execute(
|
||||
f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err
|
||||
) # for automatically installing torch according to the nvidia driver
|
||||
execute(
|
||||
f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
else:
|
||||
execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
# read yaml, remove seed kwargs of model, and then save file in the temp_dir
|
||||
yaml_path = gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir)
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e {qlib_uri}",
|
||||
wait_when_err=wait_when_err,
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config for multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'bin' / 'qrun'} {yaml_path} {fn} {exp_folder_name}",
|
||||
wait_when_err=wait_when_err,
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
_errs.update({i: errs})
|
||||
errors[fn] = _errs
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# print errors
|
||||
sys.stderr.write(f"Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
self._collect_results(exp_folder_name, dataset)
|
||||
|
||||
def _collect_results(self, exp_folder_name, dataset):
|
||||
folders = get_all_folders(exp_folder_name, dataset)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
if len(results) > 0:
|
||||
# calculating the mean and std
|
||||
sys.stderr.write(f"Calculating the mean and std of results...\n")
|
||||
results = cal_mean_std(results)
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results, dataset)
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
if wait_before_rm_env:
|
||||
input("Press Enter to Continue")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write(f"Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
if len(results) > 0:
|
||||
# calculating the mean and std
|
||||
sys.stderr.write(f"Calculating the mean and std of results...\n")
|
||||
results = cal_mean_std(results)
|
||||
# generating md table
|
||||
sys.stderr.write(f"Generating markdown table...\n")
|
||||
gen_and_save_md_table(results, dataset)
|
||||
sys.stderr.write("\n")
|
||||
# print errors
|
||||
sys.stderr.write(f"Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
sys.stderr.write("\n")
|
||||
# move results folder
|
||||
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
|
||||
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
|
||||
# move results folder
|
||||
shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}")
|
||||
shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(run) # run all the model
|
||||
fire.Fire(ModelRunner) # run all the model
|
||||
|
||||
@@ -204,7 +204,7 @@
|
||||
" },\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"class\": \"TopkDropoutStrategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"model\": model,\n",
|
||||
" \"dataset\": dataset,\n",
|
||||
|
||||
@@ -5,7 +5,7 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
@@ -31,10 +31,9 @@ if __name__ == "__main__":
|
||||
},
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
@@ -71,6 +70,10 @@ if __name__ == "__main__":
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# Signal Analysis
|
||||
sar = SigAnaRecord(recorder)
|
||||
sar.generate()
|
||||
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
par = PortAnaRecord(recorder, port_analysis_config, "day")
|
||||
|
||||
@@ -6,6 +6,7 @@ _version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is
|
||||
__version__ = _version_path.read_text(encoding="utf-8").strip()
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
import yaml
|
||||
import logging
|
||||
import platform
|
||||
@@ -151,14 +152,17 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
:param conf_path: A path to the qlib config in yml format
|
||||
"""
|
||||
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
if conf_path is None:
|
||||
config = {}
|
||||
else:
|
||||
with open(conf_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
def get_project_path(config_name="config.yaml", cur_path: Union[Path, str, None] = None) -> Path:
|
||||
"""
|
||||
If users are building a project follow the following pattern.
|
||||
- Qlib is a sub folder in project path
|
||||
@@ -187,6 +191,7 @@ def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
|
||||
"""
|
||||
if cur_path is None:
|
||||
cur_path = Path(__file__).absolute().resolve()
|
||||
cur_path = Path(cur_path)
|
||||
while True:
|
||||
if (cur_path / config_name).exists():
|
||||
return cur_path
|
||||
@@ -202,6 +207,40 @@ def auto_init(**kwargs):
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
- Skip initialization if already initialized
|
||||
|
||||
:**kwargs: it may contain following parameters
|
||||
cur_path: the start path to find the project path
|
||||
|
||||
Here are two examples of the configuration
|
||||
|
||||
Example 1)
|
||||
If you want to create a new project-specific config based on a shared configuration, you can use `conf_type: ref`
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
conf_type: ref
|
||||
qlib_cfg: '<shared_yaml_config_path>' # this could be null to reference no config from other files
|
||||
# the following configs in `qlib_cfg_update` are project-specific
|
||||
qlib_cfg_update:
|
||||
exp_manager:
|
||||
class: "MLflowExpManager"
|
||||
module_path: "qlib.workflow.expm"
|
||||
kwargs:
|
||||
uri: "file://<your mlflow experiment path>"
|
||||
default_exp_name: "Experiment"
|
||||
|
||||
Example 2)
|
||||
If you want to create a simple standalone config, you can use the following config (a.k.a. `conf_type: origin`)
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
exp_manager:
|
||||
class: "MLflowExpManager"
|
||||
module_path: "qlib.workflow.expm"
|
||||
kwargs:
|
||||
uri: "file://<your mlflow experiment path>"
|
||||
default_exp_name: "Experiment"
|
||||
|
||||
"""
|
||||
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
|
||||
|
||||
@@ -210,6 +249,7 @@ def auto_init(**kwargs):
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
logger = get_module_logger("Initialization")
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
@@ -223,8 +263,14 @@ def auto_init(**kwargs):
|
||||
# - There is a shared configure file and you don't want to edit it inplace.
|
||||
# - The shared configure may be updated later and you don't want to copy it.
|
||||
# - You have some customized config.
|
||||
qlib_conf_path = conf["qlib_cfg"]
|
||||
qlib_conf_update = conf.get("qlib_cfg_update")
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
|
||||
logger = get_module_logger("Initialization")
|
||||
qlib_conf_path = conf.get("qlib_cfg", None)
|
||||
|
||||
# merge the arguments
|
||||
qlib_conf_update = conf.get("qlib_cfg_update", {})
|
||||
for k, v in kwargs.items():
|
||||
if k in qlib_conf_update:
|
||||
logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'")
|
||||
qlib_conf_update.update(kwargs)
|
||||
|
||||
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)
|
||||
logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -50,11 +50,12 @@ def get_exchange(
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
open transaction cost. It is a ratio. The cost is proportional to your order's deal amount.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
close transaction cost. It is a ratio. The cost is proportional to your order's deal amount.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
trade_unit : int
|
||||
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
@@ -185,8 +186,10 @@ def get_strategy_executor(
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
|
||||
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
trade_strategy.reset_common_infra(common_infra)
|
||||
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor)
|
||||
trade_executor.reset_common_infra(common_infra)
|
||||
|
||||
return trade_strategy, trade_executor
|
||||
|
||||
|
||||
@@ -29,7 +29,10 @@ rtn & earning in the Account
|
||||
|
||||
|
||||
class AccumulatedInfo:
|
||||
"""accumulated trading info, including accumulated return/cost/turnover"""
|
||||
"""
|
||||
accumulated trading info, including accumulated return/cost/turnover
|
||||
AccumulatedInfo should be shared across different levels
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -62,6 +65,11 @@ class AccumulatedInfo:
|
||||
|
||||
|
||||
class Account:
|
||||
"""
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor
|
||||
Different level of executor has different Account object when calculating metrics. But the position object is shared cross all the Account object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
init_cash: float = 1e9,
|
||||
@@ -95,6 +103,8 @@ class Account:
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
# 1) the following variables are shared by multiple layers
|
||||
# - you will see a shallow copy instead of deepcopy in the NestedExecutor;
|
||||
self.init_cash = init_cash
|
||||
self.current_position: BasePosition = init_instance_by_config(
|
||||
{
|
||||
@@ -106,6 +116,9 @@ class Account:
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
|
||||
# 2) following variables are not shared between layers
|
||||
self.portfolio_metrics = None
|
||||
self.hist_positions = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
@@ -119,7 +132,8 @@ class Account:
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
self.accum_info = AccumulatedInfo()
|
||||
# NOTE:
|
||||
# `accum_info` and `current_position` are shared here
|
||||
self.portfolio_metrics = PortfolioMetrics(freq, benchmark_config)
|
||||
self.hist_positions = {}
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ class Exchange:
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
impact_cost=0.0,
|
||||
extra_quote=None,
|
||||
quote_cls=NumpyQuote,
|
||||
**kwargs,
|
||||
@@ -95,6 +96,7 @@ class Exchange:
|
||||
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
|
||||
distinguish `not set` and `disable trade_unit`
|
||||
:param min_cost: min cost, default 5
|
||||
:param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1.
|
||||
:param extra_quote: pandas, dataframe consists of
|
||||
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
|
||||
The limit indicates that the etf is tradable on a specific day.
|
||||
@@ -164,9 +166,12 @@ class Exchange:
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
|
||||
self.open_cost = open_cost
|
||||
self.close_cost = close_cost
|
||||
self.min_cost = min_cost
|
||||
self.impact_cost = impact_cost
|
||||
|
||||
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
|
||||
self.volume_threshold = volume_threshold
|
||||
self.extra_quote = extra_quote
|
||||
@@ -226,7 +231,7 @@ class Exchange:
|
||||
self.extra_quote["limit_buy"] = False
|
||||
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
|
||||
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
|
||||
self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0)
|
||||
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]
|
||||
LT_FLT = "float" # float
|
||||
@@ -685,12 +690,14 @@ class Exchange:
|
||||
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
|
||||
)
|
||||
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
|
||||
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
|
||||
"""return the real order amount after cash limit for buying.
|
||||
Parameters
|
||||
----------
|
||||
trade_price : float
|
||||
position : cash
|
||||
cost_ratio : float
|
||||
|
||||
Return
|
||||
----------
|
||||
float
|
||||
@@ -699,10 +706,10 @@ class Exchange:
|
||||
max_trade_amount = 0
|
||||
if cash >= self.min_cost:
|
||||
# critical_price means the stock transaction price when the service fee is equal to min_cost.
|
||||
critical_price = self.min_cost / self.open_cost + self.min_cost
|
||||
critical_price = self.min_cost / cost_ratio + self.min_cost
|
||||
if cash >= critical_price:
|
||||
# the service fee is equal to open_cost * trade_amount
|
||||
max_trade_amount = cash / (1 + self.open_cost) / trade_price
|
||||
# the service fee is equal to cost_ratio * trade_amount
|
||||
max_trade_amount = cash / (1 + cost_ratio) / trade_price
|
||||
else:
|
||||
# the service fee is equal to min_cost
|
||||
max_trade_amount = (cash - self.min_cost) / trade_price
|
||||
@@ -718,6 +725,7 @@ class Exchange:
|
||||
:return: trade_price, trade_val, trade_cost
|
||||
"""
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
order.deal_amount = order.amount # set to full amount and clip it step by step
|
||||
# Clipping amount first
|
||||
@@ -726,8 +734,16 @@ class Exchange:
|
||||
# - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
|
||||
self._clip_amount_by_volume(order, dealt_order_amount)
|
||||
|
||||
# TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
|
||||
trade_val = order.deal_amount * trade_price
|
||||
if not total_trade_val or np.isnan(total_trade_val):
|
||||
# TODO: assert trade_val == 0, f"trade_val != 0, total_trade_val: {total_trade_val}; order info: {order}"
|
||||
adj_cost_ratio = self.impact_cost
|
||||
else:
|
||||
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
|
||||
|
||||
if order.direction == Order.SELL:
|
||||
cost_ratio = self.close_cost
|
||||
cost_ratio = self.close_cost + adj_cost_ratio
|
||||
# sell
|
||||
# if we don't know current position, we choose to sell all
|
||||
# Otherwise, we clip the amount based on current position
|
||||
@@ -750,14 +766,18 @@ class Exchange:
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
|
||||
elif order.direction == Order.BUY:
|
||||
cost_ratio = self.open_cost
|
||||
cost_ratio = self.open_cost + adj_cost_ratio
|
||||
# buy
|
||||
if position is not None:
|
||||
cash = position.get_cash()
|
||||
trade_val = order.deal_amount * trade_price
|
||||
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
|
||||
if cash < max(trade_val * cost_ratio, self.min_cost):
|
||||
# cash cannot cover cost
|
||||
order.deal_amount = 0
|
||||
self.logger.debug(f"Order clipped due to cost higher than cash: {order}")
|
||||
elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
|
||||
# The money is not enough
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(max_buy_amount, order.deal_amount), order.factor
|
||||
)
|
||||
|
||||
@@ -130,7 +130,7 @@ class BaseExecutor:
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
# NOTE: there is a trick in the code.
|
||||
# copy is used instead of deepcopy. So positions are shared
|
||||
# shallow copy is used instead of deepcopy. So positions are shared
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
|
||||
@@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote):
|
||||
if is_single_value(start_time, end_time, self.freq, self.region):
|
||||
# this is a very special case.
|
||||
# skip aggregating function to speed-up the query calculation
|
||||
|
||||
# FIXME:
|
||||
# it will go to the else logic when it comes to the
|
||||
# 1) the day before holiday when daily trading
|
||||
# 2) the last minute of the day when intraday trading
|
||||
try:
|
||||
return self.data[stock_id].loc[start_time, field]
|
||||
except KeyError:
|
||||
|
||||
@@ -223,6 +223,12 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||
|
||||
def __str__(self):
|
||||
return self.__dict__.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
return self.__dict__.__repr__()
|
||||
|
||||
|
||||
class Position(BasePosition):
|
||||
"""Position
|
||||
@@ -345,15 +351,19 @@ class Position(BasePosition):
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
self.position[stock_id]["amount"] -= trade_amount
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
)
|
||||
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
|
||||
if np.isclose(self.position[stock_id]["amount"], trade_amount):
|
||||
# Selling all the stocks
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` considers both relative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
self._del_stock(stock_id)
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
self.position[stock_id]["amount"] -= trade_amount
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
)
|
||||
|
||||
new_cash = trade_val - cost
|
||||
if self._settle_type == self.ST_CASH:
|
||||
|
||||
102
qlib/backtest/signal.py
Normal file
102
qlib/backtest/signal.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.utils import init_instance_by_config
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..utils.resam import resam_ts_data
|
||||
import pandas as pd
|
||||
import abc
|
||||
|
||||
|
||||
class Signal(metaclass=abc.ABCMeta):
    """
    Unified interface for the prediction signals used by trading strategies.

    Some trading strategies make decisions based on prediction signals, and
    those signals may come from different sources (e.g. prepared data, or
    online prediction from a model and a dataset).

    Subclasses hide the source of the signal behind a single query method.
    """

    @abc.abstractmethod
    def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
        """
        Get the signal at the end of the decision step (from `start_time` to `end_time`).

        Returns
        -------
        Union[pd.Series, pd.DataFrame, None]:
            returns None if there is no signal in the specific day
        """
|
||||
|
||||
|
||||
class SignalWCache(Signal):
    """
    Signal backed by a pandas-based cache.

    The prepared signal is stored as an attribute, and the part matching each
    query is served from that stored data.
    """

    def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
        """

        Parameters
        ----------
        signal : Union[pd.Series, pd.DataFrame]
            The expected format of the signal is like the data below
            (the order of index is not important and can be automatically adjusted)

                instrument datetime
                SH600000   2008-01-02  0.079704
                           2008-01-03  0.120125
                           2008-01-04  0.878860
                           2008-01-07  0.505539
                           2008-01-08  0.395004
        """
        self.signal_cache = convert_index_format(signal, level="datetime")

    def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
        # The frequency of the signal may not align with the decision frequency
        # of the strategy, so the cached data is resampled for every query.
        # method="last": the latest signal leverages more recent data and is
        # therefore the one used in trading.
        return resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
|
||||
|
||||
|
||||
class ModelSignal(SignalWCache):
    """Cached signal whose values are the predictions of `model` on `dataset`."""

    def __init__(self, model: BaseModel, dataset: Dataset):
        self.model = model
        self.dataset = dataset
        scores = self.model.predict(dataset)
        # When the prediction is a DataFrame, only its first column is cached.
        super().__init__(scores.iloc[:, 0] if isinstance(scores, pd.DataFrame) else scores)

    def _update_model(self):
        """
        When using online data, update the model in each bar with the following steps:
            - update the dataset with online data; the dataset should support online update
            - make the latest prediction scores of the new bar
            - update the pred score into the latest prediction
        """
        # TODO: this method is not included in the framework and could be refactored later
        raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
def create_signal_from(
    obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
) -> Signal:
    """
    Create a signal from diverse information.

    The way the signal is built is chosen from the type of `obj`:

    - `Signal`: returned as it is
    - tuple or list: unpacked as the positional arguments of `ModelSignal`
    - dict or str: passed to `init_instance_by_config`
    - `pd.DataFrame` or `pd.Series`: wrapped into a `SignalWCache`

    Raises
    ------
    NotImplementedError
        if `obj` is none of the types above; the message names the rejected type
    """
    if isinstance(obj, Signal):
        return obj
    if isinstance(obj, (tuple, list)):
        return ModelSignal(*obj)
    if isinstance(obj, (dict, str)):
        return init_instance_by_config(obj)
    if isinstance(obj, (pd.DataFrame, pd.Series)):
        return SignalWCache(signal=obj)
    # Report the offending type so that a misconfigured `signal` argument can be diagnosed.
    raise NotImplementedError(f"This type of signal is not supported: {type(obj)}")
|
||||
@@ -222,7 +222,7 @@ class CommonInfrastructure(BaseInfrastructure):
|
||||
|
||||
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
"""level instrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
"""level infrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
|
||||
def get_support_infra(self):
|
||||
"""
|
||||
|
||||
@@ -73,6 +73,9 @@ class Config:
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
|
||||
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||
PROTOCOL_VERSION = 4
|
||||
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
DISK_DATASET_CACHE = "DiskDatasetCache"
|
||||
@@ -107,6 +110,8 @@ _default_config = {
|
||||
# for simple dataset cache
|
||||
"local_cache_path": None,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# pickle.dump protocol version
|
||||
"dump_protocol_version": PROTOCOL_VERSION,
|
||||
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
|
||||
"maxtasksperchild": None,
|
||||
# If joblib_backend is None, use loky
|
||||
@@ -239,8 +244,8 @@ HIGH_FREQ_CONFIG = {
|
||||
_default_region_config = {
|
||||
REG_CN: {
|
||||
"trade_unit": 100,
|
||||
"limit_threshold": 0.099,
|
||||
"deal_price": "vwap",
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
},
|
||||
REG_US: {
|
||||
"trade_unit": 1,
|
||||
@@ -265,6 +270,20 @@ class QlibConfig(Config):
|
||||
self.provider_uri = provider_uri
|
||||
self.mount_path = mount_path
|
||||
|
||||
@staticmethod
|
||||
def format_provider_uri(provider_uri: Union[str, dict, Path]) -> dict:
|
||||
if provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if isinstance(provider_uri, (str, dict, Path)):
|
||||
if not isinstance(provider_uri, dict):
|
||||
provider_uri = {QlibConfig.DEFAULT_FREQ: provider_uri}
|
||||
else:
|
||||
raise TypeError(f"provider_uri does not support {type(provider_uri)}")
|
||||
for freq, _uri in provider_uri.items():
|
||||
if QlibConfig.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
provider_uri[freq] = str(Path(_uri).expanduser().resolve())
|
||||
return provider_uri
|
||||
|
||||
@staticmethod
|
||||
def get_uri_type(uri: Union[str, Path]):
|
||||
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
|
||||
@@ -311,11 +330,7 @@ class QlibConfig(Config):
|
||||
def resolve_path(self):
|
||||
# resolve path
|
||||
_mount_path = self["mount_path"]
|
||||
_provider_uri = self["provider_uri"]
|
||||
if _provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if not isinstance(_provider_uri, dict):
|
||||
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
|
||||
_provider_uri = self.DataPathManager.format_provider_uri(self["provider_uri"])
|
||||
if not isinstance(_mount_path, dict):
|
||||
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
|
||||
|
||||
@@ -324,10 +339,7 @@ class QlibConfig(Config):
|
||||
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
|
||||
|
||||
# resolve
|
||||
for _freq, _uri in _provider_uri.items():
|
||||
# provider_uri
|
||||
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
|
||||
for _freq in _provider_uri.keys():
|
||||
# mount_path
|
||||
_mount_path[_freq] = (
|
||||
_mount_path[_freq]
|
||||
@@ -337,20 +349,6 @@ class QlibConfig(Config):
|
||||
self["provider_uri"] = _provider_uri
|
||||
self["mount_path"] = _mount_path
|
||||
|
||||
def get_uri_type(self):
|
||||
path = self["provider_uri"]
|
||||
if isinstance(path, Path):
|
||||
path = str(path)
|
||||
is_win = re.match("^[a-zA-Z]:.*", path) is not None # such as 'C:\\data', 'D:'
|
||||
is_nfs_or_win = (
|
||||
re.match("^[^/]+:.+", path) is not None
|
||||
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def set(self, default_conf: str = "client", **kwargs):
|
||||
"""
|
||||
configure qlib based on the input parameters
|
||||
|
||||
0
qlib/contrib/data/utils/__init__.py
Normal file
0
qlib/contrib/data/utils/__init__.py
Normal file
183
qlib/contrib/data/utils/sepdf.py
Normal file
183
qlib/contrib/data/utils/sepdf.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
from typing import Dict, Iterable
|
||||
|
||||
|
||||
def align_index(df_dict, join):
    """
    Reindex every dataframe in `df_dict` to the index of `df_dict[join]`.

    The dataframe stored under the `join` key is passed through untouched.
    When `join` is None, no dataframe is reindexed.
    A new dictionary is always returned; `df_dict` itself is not modified.
    """
    if join is None:
        return dict(df_dict.items())
    return {
        key: frame if key == join else frame.reindex(df_dict[join].index)
        for key, frame in df_dict.items()
    }
|
||||
|
||||
|
||||
# Mocking the pd.DataFrame class
|
||||
class SepDataFrame:
    """
    (Sep)arate DataFrame
    We usually concat multiple dataframes to be processed together (such as feature, label, weight, filter).
    However, they are usually used separately in the end.
    This results in extra cost for concatenating and splitting data
    (reshaping and copying data in the memory is very expensive)

    SepDataFrame tries to act like a DataFrame whose columns are a MultiIndex
    """

    def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
        """
        initialize the data based on the dataframe dictionary

        Parameters
        ----------
        df_dict : Dict[str, pd.DataFrame]
            dataframe dictionary
        join : str
            how to join the data
            It will reindex the dataframe based on the join key.
            If join is None, the reindex step will be skipped

        skip_align :
            for some cases, we can improve performance by skipping aligning index
        """
        self.join = join

        if skip_align:
            self._df_dict = df_dict
        else:
            self._df_dict = align_index(df_dict, join)

    @property
    def loc(self):
        # A new indexer is created on every access, so a previous `loc(axis=...)` call leaves no state behind
        return SDFLoc(self, join=self.join)

    @property
    def index(self):
        # the index of the dataframe stored under the join key
        return self._df_dict[self.join].index

    def apply_each(self, method: str, skip_align=True, *args, **kwargs):
        """
        Call `method` with `*args, **kwargs` on each of the dataframes.

        Assumptions:
        - inplace methods will return None

        Returns
        -------
        a new SepDataFrame made of the results; None if any of the calls returned None
        """
        inplace = False
        df_dict = {}
        for k, df in self._df_dict.items():
            df_dict[k] = getattr(df, method)(*args, **kwargs)
            if df_dict[k] is None:
                inplace = True
        if not inplace:
            return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align)

    def sort_index(self, *args, **kwargs):
        return self.apply_each("sort_index", True, *args, **kwargs)

    def copy(self, *args, **kwargs):
        return self.apply_each("copy", True, *args, **kwargs)

    def _update_join(self):
        """Re-point `join` at an existing key after the current join key has been removed."""
        if self.join not in self:
            # Fall back to the first remaining key.
            # When no dataframe is left, `join` becomes None instead of leaking a
            # StopIteration out of `__delitem__`.
            self.join = next(iter(self._df_dict), None)

    def __getitem__(self, item):
        return self._df_dict[item]

    def __setitem__(self, item: str, df: pd.DataFrame):
        # TODO: consider the join behavior
        self._df_dict[item] = df

    def __delitem__(self, item: str):
        del self._df_dict[item]
        self._update_join()

    def __contains__(self, item):
        return item in self._df_dict

    def __len__(self):
        # the length of the dataframe stored under the join key
        return len(self._df_dict[self.join])

    def droplevel(self, *args, **kwargs):
        raise NotImplementedError(f"Please implement the `droplevel` method")

    @property
    def columns(self):
        # Build the two-level columns (<key>, <column>) from empty slices of the
        # dataframes so that no data is concatenated.
        dfs = []
        for k, df in self._df_dict.items():
            df = df.head(0)
            df.columns = pd.MultiIndex.from_product([[k], df.columns])
            dfs.append(df)
        return pd.concat(dfs, axis=1).columns

    # Useless methods
    @staticmethod
    def merge(df_dict: Dict[str, pd.DataFrame], join: str):
        """Join all the other dataframes in `df_dict` onto `df_dict[join]`."""
        all_df = df_dict[join]
        for k, df in df_dict.items():
            if k != join:
                all_df = all_df.join(df)
        return all_df
|
||||
|
||||
|
||||
class SDFLoc:
    """Mock Class"""

    def __init__(self, sdf: SepDataFrame, join):
        self._sdf, self.join = sdf, join
        self.axis = None

    def __call__(self, axis):
        """Remember ``axis`` for the next lookup and return this indexer."""
        self.axis = axis
        return self

    def _select_keys(self, args):
        """Column-wise (axis=1) lookup: one key gives a frame, several keys give a SepDataFrame."""
        if isinstance(args, str):
            return self._sdf[args]
        if not isinstance(args, (tuple, list)):
            raise NotImplementedError(f"This type of input is not supported")
        picked = {key: self._sdf[key] for key in args}
        # Keep the current join key when it survives the selection, else use the first key.
        join_key = self.join if self.join in args else args[0]
        return SepDataFrame(picked, join=join_key, skip_align=True)

    def _select_rows(self, args):
        """Row-wise (axis=0) lookup applied to every stored frame."""
        sliced = {key: frame.loc(axis=0)[args] for key, frame in self._sdf._df_dict.items()}
        return SepDataFrame(sliced, join=self.join, skip_align=True)

    def __getitem__(self, args):
        if self.axis == 1:
            return self._select_keys(args)
        if self.axis == 0:
            return self._select_rows(args)

        # No axis given: a tuple is read as (rows, *columns), anything else as rows.
        if not isinstance(args, tuple):
            return self._sdf.loc(axis=0)[args]
        ax0, *ax1 = args
        result = self._sdf
        if ax1:
            result = result.loc(axis=1)[ax1]
        if ax0 is not None:
            result = result.loc(axis=0)[ax0]
        return result
|
||||
|
||||
|
||||
# Patch pandas DataFrame
|
||||
# Tricking isinstance to accept SepDataFrame as its subclass
|
||||
import builtins
|
||||
|
||||
|
||||
def _isinstance(instance, cls):
    """
    Replacement for ``isinstance`` that reports a SepDataFrame as a ``pd.DataFrame``.

    For a SepDataFrame instance the answer is True when ``cls`` is
    ``pd.DataFrame`` or an iterable containing it; every other query is
    delegated to the original ``isinstance``.
    """
    if isinstance_orig(instance, SepDataFrame):  # pylint: disable=E0602
        if isinstance(cls, Iterable):
            if any(candidate is pd.DataFrame for candidate in cls):
                return True
        elif cls is pd.DataFrame:
            return True
    return isinstance_orig(instance, cls)  # pylint: disable=E0602
|
||||
|
||||
|
||||
builtins.isinstance_orig = builtins.isinstance
|
||||
builtins.isinstance = _isinstance
|
||||
|
||||
if __name__ == "__main__":
|
||||
sdf = SepDataFrame({}, join=None)
|
||||
print(isinstance(sdf, (pd.DataFrame,)))
|
||||
print(isinstance(sdf, pd.DataFrame))
|
||||
@@ -3,15 +3,18 @@
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from logging import warn
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Union
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..backtest import get_exchange, backtest as backtest_func
|
||||
from ..utils import get_date_range
|
||||
from ..utils.resam import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..backtest import get_exchange, position, backtest as backtest_func, executor as _executor
|
||||
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
@@ -117,84 +120,129 @@ def indicator_analysis(df, method="mean"):
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
def backtest_daily(
|
||||
start_time: Union[str, pd.Timestamp],
|
||||
end_time: Union[str, pd.Timestamp],
|
||||
strategy: Union[str, dict, BaseStrategy],
|
||||
executor: Union[str, dict, _executor.BaseExecutor] = None,
|
||||
account: Union[float, int, position.Position] = 1e8,
|
||||
benchmark: str = "SH000300",
|
||||
exchange_kwargs: dict = None,
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
"""initialize the strategy and executor, then executor the backtest of daily frequency
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : Union[str, pd.Timestamp]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
end_time : Union[str, pd.Timestamp]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
strategy : Union[str, dict, BaseStrategy]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
|
||||
|
||||
- **backtest workflow related or commmon arguments**
|
||||
E.g.
|
||||
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column.
|
||||
account : float
|
||||
init account value.
|
||||
shift : int
|
||||
whether to shift prediction by one day.
|
||||
benchmark : str
|
||||
benchmark code, default is SH000905 CSI 500.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
.. code-block:: python
|
||||
# dict
|
||||
strategy = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}
|
||||
# BaseStrategy
|
||||
pred_score = pd.read_pickle("score.pkl")["score"]
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
"signal": pred_score,
|
||||
}
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
# str example.
|
||||
# 1) specify a pickle object
|
||||
# - path like 'file:///<path to pickle file>/obj.pkl'
|
||||
# 2) specify a class name
|
||||
# - "ClassName": getattr(module, "ClassName")() will be used.
|
||||
# 3) specify module path with class name
|
||||
# - "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
|
||||
|
||||
- **strategy related arguments**
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
executor : Union[str, dict, BaseExecutor]
|
||||
for initializing the outermost executor.
|
||||
benchmark: str
|
||||
the benchmark for reporting.
|
||||
account : Union[float, int, Position]
|
||||
information for describing how to creating the account
|
||||
For `float` or `int`:
|
||||
Using Account with only initial cash
|
||||
For `Position`:
|
||||
Using Account with a Position
|
||||
exchange_kwargs : dict
|
||||
the kwargs for initializing Exchange
|
||||
E.g.
|
||||
|
||||
sell_limit = margin
|
||||
.. code-block:: python
|
||||
|
||||
- else:
|
||||
exchange_kwargs = {
|
||||
"freq": freq,
|
||||
"limit_threshold": None, # limit_threshold is None, using C.limit_threshold
|
||||
"deal_price": None, # deal_price is None, using C.deal_price
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
pos_type : str
|
||||
the type of Position.
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
- **exchange related arguments**
|
||||
|
||||
exchange: Exchange()
|
||||
pass the exchange for speeding up.
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. The default value is 0.002(0.2%).
|
||||
close_cost : float
|
||||
close transaction cost. The default value is 0.002(0.2%).
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
Returns
|
||||
-------
|
||||
report_normal: pd.DataFrame
|
||||
backtest report
|
||||
positions_normal: pd.DataFrame
|
||||
backtest positions
|
||||
|
||||
"""
|
||||
warnings.warn("this function is deprecated, please use backtest function in qlib.backtest", DeprecationWarning)
|
||||
report_dict = backtest_func(
|
||||
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
|
||||
freq = "day"
|
||||
if executor is None:
|
||||
executor_config = {
|
||||
"time_per_step": freq,
|
||||
"generate_portfolio_metrics": True,
|
||||
}
|
||||
executor = _executor.SimulatorExecutor(**executor_config)
|
||||
_exchange_kwargs = {
|
||||
"freq": freq,
|
||||
"limit_threshold": None,
|
||||
"deal_price": None,
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
if exchange_kwargs is not None:
|
||||
_exchange_kwargs.update(exchange_kwargs)
|
||||
|
||||
portfolio_metric_dict, indicator_dict = backtest_func(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
strategy=strategy,
|
||||
executor=executor,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
exchange_kwargs=_exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
return report_dict.get("report_df"), report_dict.get("positions")
|
||||
analysis_freq = "{0}{1}".format(*Freq.parse(freq))
|
||||
|
||||
report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)
|
||||
|
||||
return report_normal, positions_normal
|
||||
|
||||
|
||||
def long_short_backtest(
|
||||
@@ -327,7 +375,12 @@ def t_run():
|
||||
pred["datetime"] = pd.to_datetime(pred["datetime"])
|
||||
pred = pred.set_index([pred.columns[0], pred.columns[1]])
|
||||
pred = pred.iloc[:9000]
|
||||
report_df, positions = backtest(pred=pred)
|
||||
strategy_config = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
"signal": pred,
|
||||
}
|
||||
report_df, positions = backtest_daily(start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_config)
|
||||
print(report_df.head())
|
||||
print(positions.keys())
|
||||
print(positions[list(positions.keys())[0]])
|
||||
|
||||
@@ -30,8 +30,10 @@ try:
|
||||
from .pytorch_nn import DNNModelPytorch
|
||||
from .pytorch_tabnet import TabnetModel
|
||||
from .pytorch_sfm import SFM_Model
|
||||
from .pytorch_tcn import TCN
|
||||
from .pytorch_add import ADD
|
||||
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
|
||||
@@ -38,6 +38,8 @@ class CatBoostModel(Model, FeatureInt):
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
|
||||
@@ -64,6 +64,8 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
# initialize the sample weights
|
||||
N, F = x_train.shape
|
||||
|
||||
@@ -14,17 +14,20 @@ from ...model.interpret.base import LightGBMFInt
|
||||
class LGBModel(ModelFT, LightGBMFInt):
|
||||
"""LightGBM Model"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
def __init__(self, loss="mse", early_stopping_rounds=50, **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self.params = {"objective": loss, "verbosity": -1}
|
||||
self.params.update(kwargs)
|
||||
self.early_stopping_rounds = early_stopping_rounds
|
||||
self.model = None
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
@@ -42,7 +45,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
early_stopping_rounds=None,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
@@ -54,7 +57,9 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
early_stopping_rounds=(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
),
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
@@ -83,6 +88,8 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
|
||||
@@ -82,6 +82,8 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
if df_train.empty or df_valid.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_train["feature"], df_valid["label"]
|
||||
|
||||
@@ -51,6 +51,8 @@ class LinearModel(Model):
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
if df_train.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
X, y = df_train["feature"].values, np.squeeze(df_train["label"].values)
|
||||
|
||||
if self.estimator in [self.OLS, self.RIDGE, self.LASSO]:
|
||||
|
||||
789
qlib/contrib/model/pytorch_adarnn.py
Normal file
789
qlib/contrib/model/pytorch_adarnn.py
Normal file
@@ -0,0 +1,789 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
import os
|
||||
from pdb import set_trace
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import copy
|
||||
from typing import Text, Union
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.autograd import Function
|
||||
from qlib.contrib.model.pytorch_utils import count_parameters
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import get_or_create_path
|
||||
|
||||
|
||||
class ADARNN(Model):
    """ADARNN Model

    Parameters
    ----------
    d_feat : int
        input dimension for each time step
    hidden_size : int
        hidden size used for every GRU layer
    num_layers : int
        number of GRU layers
    dropout : float
        dropout value forwarded to the network
    n_epochs : int
        maximum number of training epochs
    pre_epoch : int
        epochs with an index below this value train through ``forward_pre_train``,
        later epochs through ``forward_Boosting``
    dw : float
        weight of the transfer loss inside the total training loss
    loss_type : str
        transfer loss type, forwarded to the network as ``trans_loss``
    len_seq : int
        sequence length forwarded to the network
    len_win : int
        window forwarded to ``forward_pre_train``
    lr : float
        learning rate of the optimizer
    metric: str
        the evaluate metric used in early stop
    batch_size : int
        batch size for training loaders and inference
    early_stop : int
        number of epochs without improvement before training stops
    optimizer : str
        optimizer name
    n_splits : int
        number of chunks the training days are split into
    GPU : str
        the GPU ID(s) used for training
    seed : int
        seed for numpy and torch; nothing is seeded when None
    """

    def __init__(
        self,
        d_feat=6,
        hidden_size=64,
        num_layers=2,
        dropout=0.0,
        n_epochs=200,
        pre_epoch=40,
        dw=0.5,
        loss_type="cosine",
        len_seq=60,
        len_win=0,
        lr=0.001,
        metric="mse",
        batch_size=2000,
        early_stop=20,
        loss="mse",
        optimizer="adam",
        n_splits=2,
        GPU=0,
        seed=None,
        **kwargs
    ):
        # Set logger.
        self.logger = get_module_logger("ADARNN")
        self.logger.info("ADARNN pytorch version...")
        os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)

        # set hyper-parameters.
        self.d_feat = d_feat
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.n_epochs = n_epochs
        self.pre_epoch = pre_epoch
        self.dw = dw
        self.loss_type = loss_type
        self.len_seq = len_seq
        self.len_win = len_win
        self.lr = lr
        self.metric = metric
        self.batch_size = batch_size
        self.early_stop = early_stop
        self.optimizer = optimizer.lower()
        self.loss = loss
        self.n_splits = n_splits
        self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
        self.seed = seed

        self.logger.info(
            "ADARNN parameters setting:"
            "\nd_feat : {}"
            "\nhidden_size : {}"
            "\nnum_layers : {}"
            "\ndropout : {}"
            "\nn_epochs : {}"
            "\nlr : {}"
            "\nmetric : {}"
            "\nbatch_size : {}"
            "\nearly_stop : {}"
            "\noptimizer : {}"
            "\nloss_type : {}"
            "\nvisible_GPU : {}"
            "\nuse_GPU : {}"
            "\nseed : {}".format(
                d_feat,
                hidden_size,
                num_layers,
                dropout,
                n_epochs,
                lr,
                metric,
                batch_size,
                early_stop,
                optimizer.lower(),
                loss,
                GPU,
                self.use_gpu,
                seed,
            )
        )

        if self.seed is not None:
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        n_hiddens = [hidden_size for _ in range(num_layers)]
        self.model = AdaRNN(
            use_bottleneck=False,
            bottleneck_width=64,
            n_input=d_feat,
            n_hiddens=n_hiddens,
            n_output=1,
            dropout=dropout,
            model_type="AdaRNN",
            len_seq=len_seq,
            trans_loss=loss_type,
        )
        self.logger.info("model:\n{:}".format(self.model))
        self.logger.info("model size: {:.4f} MB".format(count_parameters(self.model)))

        if optimizer.lower() == "adam":
            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        elif optimizer.lower() == "gd":
            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
        else:
            raise NotImplementedError("optimizer {} is not supported!".format(optimizer))

        self.fitted = False
        self._to_device(self.model)

    @property
    def use_gpu(self):
        return self.device != torch.device("cpu")

    def _to_device(self, obj):
        """Move a tensor or module to the GPU when one is in use; leave it untouched on CPU.

        The GPU path keeps the plain ``.cuda()`` call of the original code, so
        the only behavior change is that a CPU-only run no longer crashes.
        """
        return obj.cuda() if self.use_gpu else obj

    def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
        """Run one training epoch over the zipped domain loaders.

        Returns ``(weight_mat, dist_mat)`` for boosting epochs
        (``epoch >= pre_epoch``) and ``(weight_mat, None)`` for pre-training epochs.
        """
        self.model.train()
        criterion = nn.MSELoss()
        dist_mat = self._to_device(torch.zeros(self.num_layers, self.len_seq))
        # Stays None when every batch is skipped, so the pre-training return below is safe.
        out_weight_list = None
        # One batch is drawn from every loader per step, so the domain pairs are fixed.
        pairs = get_index(len(train_loader_list) - 1)

        for data_all in zip(*train_loader_list):
            list_feat = [self._to_device(data[0]).float() for data in data_all]
            list_label = [self._to_device(data[1]).float() for data in data_all]

            # Source and target halves must have equal size; skip ragged batches.
            if any(list_feat[s1].shape[0] != list_feat[s2].shape[0] for s1, s2 in pairs):
                continue

            total_loss = self._to_device(torch.zeros(1))
            for s1, s2 in pairs:
                feature_s, feature_t = list_feat[s1], list_feat[s2]
                label_reg_s, label_reg_t = list_label[s1], list_label[s2]
                feature_all = torch.cat((feature_s, feature_t), 0)

                if epoch < self.pre_epoch:
                    pred_all, loss_transfer, out_weight_list = self.model.forward_pre_train(
                        feature_all, len_win=self.len_win
                    )
                else:
                    pred_all, loss_transfer, dist, weight_mat = self.model.forward_Boosting(feature_all, weight_mat)
                    dist_mat = dist_mat + dist
                pred_s = pred_all[0 : feature_s.size(0)]
                pred_t = pred_all[feature_s.size(0) :]

                loss_s = criterion(pred_s, label_reg_s)
                loss_t = criterion(pred_t, label_reg_t)

                total_loss = total_loss + loss_s + loss_t + self.dw * loss_transfer
            self.train_optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
            self.train_optimizer.step()

        if epoch >= self.pre_epoch:
            if epoch > self.pre_epoch:
                weight_mat = self.model.update_weight_Boosting(weight_mat, dist_old, dist_mat)
            return weight_mat, dist_mat
        if out_weight_list is None:
            # No batch was trained in this epoch: keep the incoming weights.
            return weight_mat, None
        return self.transform_type(out_weight_list), None

    def calc_all_metrics(self, pred):
        """pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
        res = {}
        ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
        rank_ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score, method="spearman"))
        res["ic"] = ic.mean()
        res["icir"] = ic.mean() / ic.std()
        res["ric"] = rank_ic.mean()
        res["ricir"] = rank_ic.mean() / rank_ic.std()
        # Negated mean *squared* error, so that a larger value is better like the other metrics.
        res["mse"] = -((pred["label"] - pred["score"]) ** 2).mean()
        res["loss"] = res["mse"]
        return res

    def test_epoch(self, df):
        """Score ``df`` with the current model and return the metrics dictionary."""
        self.model.eval()
        preds = self.infer(df["feature"])
        label = df["label"].squeeze()
        preds = pd.DataFrame({"label": label, "score": preds}, index=df.index)
        metrics = self.calc_all_metrics(preds)
        return metrics

    def log_metrics(self, mode, metrics):
        metrics = ["{}/{}: {:.6f}".format(k, mode, v) for k, v in metrics.items()]
        metrics = ", ".join(metrics)
        self.logger.info(metrics)

    def fit(
        self,
        dataset: DatasetH,
        evals_result=None,
        save_path=None,
    ):
        """Train with early stopping on ``self.metric`` and return the best validation score.

        ``evals_result`` (optional dict) receives the per-epoch ``"train"`` and
        ``"valid"`` scores. The best parameters are loaded back into the model
        and saved to ``save_path``.

        Raises
        ------
        ValueError
            when the train or valid segment of the dataset is empty
        """
        # A fresh dict per call; a shared default would leak results between fits.
        if evals_result is None:
            evals_result = {}

        df_train, df_valid = dataset.prepare(
            ["train", "valid"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )
        if df_train.empty or df_valid.empty:
            raise ValueError("Empty data from dataset, please check your dataset config.")
        # splits = ['2011-06-30']
        days = df_train.index.get_level_values(level=0).unique()
        train_splits = np.array_split(days, self.n_splits)
        train_splits = [df_train[s[0] : s[-1]] for s in train_splits]
        train_loader_list = [get_stock_loader(df, self.batch_size) for df in train_splits]

        save_path = get_or_create_path(save_path)
        stop_steps = 0
        best_score = -np.inf
        best_epoch = 0
        # Start from the current weights so a run without any improving epoch still has parameters.
        best_param = copy.deepcopy(self.model.state_dict())
        evals_result["train"] = []
        evals_result["valid"] = []

        # train
        self.logger.info("training...")
        self.fitted = True
        weight_mat, dist_mat = None, None

        for step in range(self.n_epochs):
            self.logger.info("Epoch%d:", step)
            self.logger.info("training...")
            weight_mat, dist_mat = self.train_AdaRNN(train_loader_list, step, dist_mat, weight_mat)
            self.logger.info("evaluating...")
            train_metrics = self.test_epoch(df_train)
            valid_metrics = self.test_epoch(df_valid)
            self.log_metrics("train: ", train_metrics)
            self.log_metrics("valid: ", valid_metrics)

            valid_score = valid_metrics[self.metric]
            train_score = train_metrics[self.metric]
            evals_result["train"].append(train_score)
            evals_result["valid"].append(valid_score)
            if valid_score > best_score:
                best_score = valid_score
                stop_steps = 0
                best_epoch = step
                best_param = copy.deepcopy(self.model.state_dict())
            else:
                stop_steps += 1
                if stop_steps >= self.early_stop:
                    self.logger.info("early stop")
                    break

        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
        self.model.load_state_dict(best_param)
        torch.save(best_param, save_path)

        if self.use_gpu:
            torch.cuda.empty_cache()
        return best_score

    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
        if not self.fitted:
            raise ValueError("model is not fitted yet!")
        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
        return self.infer(x_test)

    def infer(self, x_test):
        """Predict in batches and return a Series indexed like ``x_test``."""
        index = x_test.index
        self.model.eval()
        x_values = x_test.values
        sample_num = x_values.shape[0]
        # Flat rows are reshaped to (sample, d_feat, step) and transposed to (sample, step, d_feat).
        x_values = x_values.reshape(sample_num, self.d_feat, -1).transpose(0, 2, 1)
        preds = []

        for begin in range(0, sample_num, self.batch_size):
            end = min(begin + self.batch_size, sample_num)
            x_batch = self._to_device(torch.from_numpy(x_values[begin:end]).float())

            with torch.no_grad():
                pred = self.model.predict(x_batch).detach().cpu().numpy()

            preds.append(pred)

        return pd.Series(np.concatenate(preds), index=index)

    def transform_type(self, init_weight):
        """Copy the per-layer gate weights into a ``(num_layers, len_seq)`` tensor."""
        weight = self._to_device(torch.ones(self.num_layers, self.len_seq))
        for i in range(self.num_layers):
            for j in range(self.len_seq):
                weight[i, j] = init_weight[i][j].item()
        return weight
|
||||
|
||||
|
||||
class data_loader(Dataset):
    """Torch dataset over a DataFrame holding ``"feature"`` and ``"label"`` column groups.

    Parameters
    ----------
    df : pd.DataFrame
        frame whose ``df["feature"]`` is a 2-D block with ``d_feat * steps``
        columns per row and whose ``df["label"]`` holds one label per row
    d_feat : int
        number of features per time step (default 6); the number of steps is
        derived from the feature width instead of being fixed to 60

    Raises
    ------
    ValueError
        when the feature width is not a multiple of ``d_feat``
    """

    def __init__(self, df, d_feat=6):
        feature_values = df["feature"].values
        label_values = df["label"].values
        self.df_index = df.index

        n_samples, n_columns = feature_values.shape
        if n_columns % d_feat != 0:
            raise ValueError(
                "feature width {} is not a multiple of d_feat {}".format(n_columns, d_feat)
            )
        # (sample, d_feat * step) -> (sample, d_feat, step) -> (sample, step, d_feat)
        self.df_feature = torch.tensor(
            feature_values.reshape(n_samples, d_feat, n_columns // d_feat).transpose(0, 2, 1),
            dtype=torch.float32,
        )
        self.df_label_reg = torch.tensor(label_values.reshape(-1), dtype=torch.float32)

    def __getitem__(self, index):
        sample, label_reg = self.df_feature[index], self.df_label_reg[index]
        return sample, label_reg

    def __len__(self):
        return len(self.df_feature)
|
||||
|
||||
|
||||
def get_stock_loader(df, batch_size, shuffle=True):
    """Wrap ``df`` in a ``data_loader`` dataset and return a ``DataLoader`` over it."""
    dataset = data_loader(df)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
||||
|
||||
|
||||
def get_index(num_domain=2):
    """Return every pair ``(i, j)`` with ``0 <= i < j <= num_domain`` in lexicographic order."""
    return [(first, second) for first in range(num_domain) for second in range(first + 1, num_domain + 1)]
|
||||
|
||||
|
||||
class AdaRNN(nn.Module):
|
||||
"""
|
||||
model_type: 'Boosting', 'AdaRNN'
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_bottleneck=False,
|
||||
bottleneck_width=256,
|
||||
n_input=128,
|
||||
n_hiddens=[64, 64],
|
||||
n_output=6,
|
||||
dropout=0.0,
|
||||
len_seq=9,
|
||||
model_type="AdaRNN",
|
||||
trans_loss="mmd",
|
||||
):
|
||||
super(AdaRNN, self).__init__()
|
||||
self.use_bottleneck = use_bottleneck
|
||||
self.n_input = n_input
|
||||
self.num_layers = len(n_hiddens)
|
||||
self.hiddens = n_hiddens
|
||||
self.n_output = n_output
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
for hidden in n_hiddens:
|
||||
rnn = nn.GRU(input_size=in_size, num_layers=1, hidden_size=hidden, batch_first=True, dropout=dropout)
|
||||
features.append(rnn)
|
||||
in_size = hidden
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
if use_bottleneck == True: # finance
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Linear(n_hiddens[-1], bottleneck_width),
|
||||
nn.Linear(bottleneck_width, bottleneck_width),
|
||||
nn.BatchNorm1d(bottleneck_width),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(),
|
||||
)
|
||||
self.bottleneck[0].weight.data.normal_(0, 0.005)
|
||||
self.bottleneck[0].bias.data.fill_(0.1)
|
||||
self.bottleneck[1].weight.data.normal_(0, 0.005)
|
||||
self.bottleneck[1].bias.data.fill_(0.1)
|
||||
self.fc = nn.Linear(bottleneck_width, n_output)
|
||||
torch.nn.init.xavier_normal_(self.fc.weight)
|
||||
else:
|
||||
self.fc_out = nn.Linear(n_hiddens[-1], self.n_output)
|
||||
|
||||
if self.model_type == "AdaRNN":
|
||||
gate = nn.ModuleList()
|
||||
for i in range(len(n_hiddens)):
|
||||
gate_weight = nn.Linear(len_seq * self.hiddens[i] * 2, len_seq)
|
||||
gate.append(gate_weight)
|
||||
self.gate = gate
|
||||
|
||||
bnlst = nn.ModuleList()
|
||||
for i in range(len(n_hiddens)):
|
||||
bnlst.append(nn.BatchNorm1d(len_seq))
|
||||
self.bn_lst = bnlst
|
||||
self.softmax = torch.nn.Softmax(dim=0)
|
||||
self.init_layers()
|
||||
|
||||
def init_layers(self):
|
||||
for i in range(len(self.hiddens)):
|
||||
self.gate[i].weight.data.normal_(0, 0.05)
|
||||
self.gate[i].bias.data.fill_(0.0)
|
||||
|
||||
def forward_pre_train(self, x, len_win=0):
    """Forward pass for the pre-training epochs.

    ``x`` holds the source batch followed by the target batch along dim 0.
    Returns ``(fc_out, loss_transfer, out_weight_list)``: the squeezed head
    output for the last time step, the accumulated weighted transfer loss
    between source and target hidden states, and the per-layer gate weights
    produced by ``gru_features``.

    For every time step ``j`` the source state is compared with the target
    states in the window ``[j - len_win, j + len_win]`` clipped to the sequence.
    """
    out = self.gru_features(x)
    fea = out[0]  # [2N,L,H]
    if self.use_bottleneck == True:  # noqa: E712
        fea_bottleneck = self.bottleneck(fea[:, -1, :])
        fc_out = self.fc(fea_bottleneck).squeeze()
    else:
        fc_out = self.fc_out(fea[:, -1, :]).squeeze()  # [N,]

    out_list_all, out_weight_list = out[1], out[2]
    out_list_s, out_list_t = self.get_features(out_list_all)
    # Allocate on the input's device instead of forcing CUDA, so CPU inputs work too.
    loss_transfer = torch.zeros((1,), device=x.device)
    h_start = 0
    for i in range(len(out_list_s)):
        criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
        for j in range(h_start, self.len_seq, 1):
            i_start = max(j - len_win, 0)
            i_end = min(j + len_win, self.len_seq - 1)
            for k in range(i_start, i_end + 1):
                weight = (
                    out_weight_list[i][j]
                    if self.model_type == "AdaRNN"
                    else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
                )
                loss_transfer = loss_transfer + weight * criterion_transder.compute(
                    out_list_s[i][:, j, :], out_list_t[i][:, k, :]
                )
    return fc_out, loss_transfer, out_weight_list
|
||||
|
||||
def gru_features(self, x, predict=False):
|
||||
x_input = x
|
||||
out = None
|
||||
out_lis = []
|
||||
out_weight_list = [] if (self.model_type == "AdaRNN") else None
|
||||
for i in range(self.num_layers):
|
||||
out, _ = self.features[i](x_input.float())
|
||||
x_input = out
|
||||
out_lis.append(out)
|
||||
if self.model_type == "AdaRNN" and predict == False:
|
||||
out_gate = self.process_gate_weight(x_input, i)
|
||||
out_weight_list.append(out_gate)
|
||||
return out, out_lis, out_weight_list
|
||||
|
||||
def process_gate_weight(self, out, index):
|
||||
x_s = out[0 : int(out.shape[0] // 2)]
|
||||
x_t = out[out.shape[0] // 2 : out.shape[0]]
|
||||
x_all = torch.cat((x_s, x_t), 2)
|
||||
x_all = x_all.view(x_all.shape[0], -1)
|
||||
weight = torch.sigmoid(self.bn_lst[index](self.gate[index](x_all.float())))
|
||||
weight = torch.mean(weight, dim=0)
|
||||
res = self.softmax(weight).squeeze()
|
||||
return res
|
||||
|
||||
def get_features(self, output_list):
    """Split every tensor of ``output_list`` in half along its first dimension.

    Args:
        output_list: Iterable of tensors.

    Returns:
        A tuple ``(fea_list_src, fea_list_tar)`` of two lists: the first
        ``size(0) // 2`` rows of every tensor, and the remaining rows of every
        tensor, in the same order as ``output_list``.
    """
    halves = [(fea[: fea.size(0) // 2], fea[fea.size(0) // 2 :]) for fea in output_list]
    fea_list_src = [first for first, _ in halves]
    fea_list_tar = [second for _, second in halves]
    return fea_list_src, fea_list_tar
|
||||
|
||||
# For Boosting-based
|
||||
def forward_Boosting(self, x, weight_mat=None):
    """Forward pass that also accumulates a weighted per-step transfer loss.

    The prediction is made from the last time step of the last layer. The
    transfer loss compares, for every layer ``i`` and step ``j``, the first
    and second half of the batch (as split by ``self.get_features``) using
    ``TransferLoss(loss_type=self.trans_loss)``.

    All helper tensors are created on the device of the extracted features
    instead of through ``.cuda()``, so the method also runs on CPU and on a
    non-default GPU.

    Args:
        x: Model input, forwarded to ``self.gru_features``.
        weight_mat: Optional weights indexed as ``[layer, step]``. When
            ``None``, every step gets the uniform weight ``1 / self.len_seq``.

    Returns:
        A tuple ``(fc_out, loss_transfer, dist_mat, weight)``: the squeezed
        prediction, the accumulated weighted transfer loss (shape ``(1,)``),
        the ``(num_layers, len_seq)`` matrix of unweighted losses, and the
        weight matrix that was used.
    """
    out = self.gru_features(x)
    fea = out[0]
    last_step = fea[:, -1, :]
    if self.use_bottleneck:
        fc_out = self.fc(self.bottleneck(last_step)).squeeze()
    else:
        fc_out = self.fc_out(last_step).squeeze()

    out_list_s, out_list_t = self.get_features(out[1])
    device = fea.device
    loss_transfer = torch.zeros((1,), device=device)
    if weight_mat is None:
        weight = 1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq, device=device)
    else:
        weight = weight_mat
    dist_mat = torch.zeros(self.num_layers, self.len_seq, device=device)
    for i, (fea_s, fea_t) in enumerate(zip(out_list_s, out_list_t)):
        criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=fea_s.shape[2])
        for j in range(self.len_seq):
            loss_trans = criterion_transder.compute(fea_s[:, j, :], fea_t[:, j, :])
            loss_transfer = loss_transfer + weight[i, j] * loss_trans
            # Keep the raw distance so the caller can update the weights later.
            dist_mat[i, j] = loss_trans
    return fc_out, loss_transfer, dist_mat, weight
|
||||
|
||||
# For Boosting-based
|
||||
def update_weight_Boosting(self, weight_mat, dist_old, dist_new):
    """Boost the weights of entries whose distance grew, then L1-normalise rows.

    Entries where ``dist_new`` exceeds ``dist_old`` by more than ``1e-5`` are
    multiplied by ``1 + sigmoid(dist_new - dist_old)``. This update is written
    into ``weight_mat`` in place; the row-normalised result is returned as a
    new tensor.

    Args:
        weight_mat: Weight matrix; it must have ``self.len_seq`` columns.
        dist_old: Previous distance matrix (detached before use).
        dist_new: Current distance matrix (detached before use).

    Returns:
        ``weight_mat`` after boosting, with every row divided by its L1 norm.
    """
    margin = 1e-5
    prev_dist, cur_dist = dist_old.detach(), dist_new.detach()
    grew = cur_dist > prev_dist + margin
    boost = 1 + torch.sigmoid(cur_dist[grew] - prev_dist[grew])
    weight_mat[grew] = weight_mat[grew] * boost
    row_l1 = torch.norm(weight_mat, dim=1, p=1)
    # Expand the per-row norm to one value per column before dividing.
    denom = row_l1.t().unsqueeze(1).repeat(1, self.len_seq)
    return weight_mat / denom
|
||||
|
||||
def predict(self, x):
    """Return the prediction made from the last time step of the features.

    Features are extracted with ``self.gru_features(x, predict=True)``. With
    the bottleneck enabled the last step goes through ``self.bottleneck`` and
    ``self.fc``; otherwise it goes through ``self.fc_out``.

    Args:
        x: Model input, forwarded to ``self.gru_features``.

    Returns:
        The squeezed output of the final linear head.
    """
    features = self.gru_features(x, predict=True)[0]
    last_step = features[:, -1, :]
    if self.use_bottleneck == True:  # noqa: E712 - keeps the original comparison semantics
        return self.fc(self.bottleneck(last_step)).squeeze()
    return self.fc_out(last_step).squeeze()
|
||||
|
||||
|
||||
class TransferLoss(object):
    """Select one of several source/target discrepancy losses by name."""

    def __init__(self, loss_type="cosine", input_dim=512):
        """
        Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine(cos), kl, js, mine, adv, pairwise
        """
        self.loss_type = loss_type
        # Feature dimension, only used by the "mine" and "adv" losses.
        self.input_dim = input_dim

    def compute(self, X, Y):
        """Compute adaptation loss

        Arguments:
            X {tensor} -- source matrix
            Y {tensor} -- target matrix

        Returns:
            [tensor] -- transfer loss

        Raises:
            ValueError -- if ``self.loss_type`` is not a supported name
        """
        loss_type = self.loss_type
        if loss_type in ("mmd_lin", "mmd"):
            loss = MMD_loss(kernel_type="linear")(X, Y)
        elif loss_type == "coral":
            loss = CORAL(X, Y)
        elif loss_type in ("cosine", "cos"):
            loss = 1 - cosine(X, Y)
        elif loss_type == "kl":
            loss = kl_div(X, Y)
        elif loss_type == "js":
            loss = js(X, Y)
        elif loss_type == "mine":
            # Place the estimator on the inputs' device instead of forcing CUDA.
            mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(X.device)
            loss = mine_model(X, Y)
        elif loss_type == "adv":
            loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
        elif loss_type == "mmd_rbf":
            loss = MMD_loss(kernel_type="rbf")(X, Y)
        elif loss_type == "pairwise":
            loss = torch.norm(pairwise_dist(X, Y))
        else:
            # Previously this fell through to ``return loss`` and failed with
            # an UnboundLocalError that did not name the offending value.
            raise ValueError(f"Unsupported loss_type: {loss_type!r}")

        return loss
|
||||
|
||||
|
||||
def cosine(source, target):
    """Cosine similarity between the mean feature vectors of two batches.

    Each input is averaged over its first (sample) dimension, and the cosine
    similarity of the two resulting mean vectors is returned. Averaging over
    all elements instead would reduce both inputs to scalars, for which the
    similarity carries no information about the feature distributions.

    Args:
        source: Source matrix, one sample per row.
        target: Target matrix, one sample per row.

    Returns:
        A zero-dimensional tensor holding the similarity.
    """
    source_mean, target_mean = source.mean(0), target.mean(0)
    similarity = nn.CosineSimilarity(dim=0)(source_mean, target_mean)
    return similarity.mean()
|
||||
|
||||
|
||||
class ReverseLayerF(Function):
    """Autograd function that passes values through and reverses gradients.

    The forward pass returns a view of its input; the backward pass negates
    the incoming gradient and scales it by ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Store ``alpha`` for the backward pass and return ``x`` as a view."""
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the negated, ``alpha``-scaled gradient for ``x``."""
        reversed_grad = grad_output.neg() * ctx.alpha
        # One entry per forward argument: ``alpha`` receives no gradient.
        return reversed_grad, None
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
    """Two-layer perceptron that outputs one sigmoid score per input row."""

    def __init__(self, input_dim=256, hidden_dim=256):
        """Create the two linear layers.

        Args:
            input_dim: Number of input features.
            hidden_dim: Width of the hidden layer.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dis1 = nn.Linear(input_dim, hidden_dim)
        self.dis2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Apply linear -> ReLU -> linear -> sigmoid to ``x``."""
        hidden = F.relu(self.dis1(x))
        return torch.sigmoid(self.dis2(hidden))
|
||||
|
||||
|
||||
def adv(source, target, input_dim=256, hidden_dim=512):
    """Adversarial domain loss between ``source`` and ``target`` features.

    Both inputs pass through ``ReverseLayerF`` (gradient reversal with
    ``alpha=1``) and a freshly constructed ``Discriminator``. The result is
    the sum of the binary cross-entropy losses for labelling source rows as 1
    and target rows as 0.

    The discriminator and the label tensors are created on ``source.device``
    rather than through ``.cuda()``, so the function also runs on CPU and on
    a non-default GPU.

    NOTE(review): a new, randomly initialised discriminator is built on every
    call, so its parameters are not shared between calls - confirm that this
    is intended.

    Args:
        source: Source feature matrix, one sample per row.
        target: Target feature matrix, one sample per row.
        input_dim: Feature dimension expected by the discriminator.
        hidden_dim: Hidden width of the discriminator.

    Returns:
        The summed source and target domain-classification loss.
    """
    device = source.device
    domain_loss = nn.BCELoss()
    adv_net = Discriminator(input_dim, hidden_dim).to(device)
    # Domain labels as column vectors to match the discriminator output.
    domain_src = torch.ones(len(source), 1, device=device)
    domain_tar = torch.zeros(len(target), 1, device=device)
    pred_src = adv_net(ReverseLayerF.apply(source, 1))
    pred_tar = adv_net(ReverseLayerF.apply(target, 1))
    return domain_loss(pred_src, domain_src) + domain_loss(pred_tar, domain_tar)
|
||||
|
||||
|
||||
def CORAL(source, target):
    """Squared Frobenius distance between source and target covariances.

    The covariance of each matrix is computed over its rows, and the squared
    difference of the two covariances is summed and divided by ``4 * d * d``,
    where ``d`` is the number of feature columns.

    The helper vectors of ones are created with the dtype and device of the
    corresponding input instead of through ``.cuda()``, so the function also
    runs on CPU and on a non-default GPU.

    Args:
        source: Source matrix of shape ``(ns, d)``.
        target: Target matrix of shape ``(nt, d)``.

    Returns:
        A zero-dimensional tensor holding the loss. With a single row in
        either input the ``n - 1`` divisor is zero and the result is not
        finite.
    """
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)

    # source covariance
    col_sum_s = torch.ones((1, ns), dtype=source.dtype, device=source.device) @ source
    cs = (source.t() @ source - (col_sum_s.t() @ col_sum_s) / ns) / (ns - 1)

    # target covariance
    col_sum_t = torch.ones((1, nt), dtype=target.dtype, device=target.device) @ target
    ct = (target.t() @ target - (col_sum_t.t() @ col_sum_t) / nt) / (nt - 1)

    # squared frobenius norm, scaled
    return (cs - ct).pow(2).sum() / (4 * d * d)
|
||||
|
||||
|
||||
class MMD_loss(nn.Module):
    """Maximum mean discrepancy loss with a linear or multi-bandwidth RBF kernel."""

    def __init__(self, kernel_type="linear", kernel_mul=2.0, kernel_num=5):
        """Store the kernel configuration.

        Args:
            kernel_type: ``"linear"`` or ``"rbf"``.
            kernel_mul: Ratio between consecutive RBF bandwidths.
            kernel_num: Number of RBF kernels that are summed.
        """
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Sum of ``kernel_num`` RBF kernels over all pairs of rows.

        Args:
            source: Source matrix, one sample per row.
            target: Target matrix, one sample per row.
            kernel_mul: Ratio between consecutive bandwidths.
            kernel_num: Number of kernels.
            fix_sigma: Base bandwidth; when falsy it is estimated as the mean
                squared distance over all distinct pairs.

        Returns:
            A square matrix over the concatenated rows of ``source`` and
            ``target`` holding the summed kernel values.
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        # Broadcasting builds every pairwise difference between rows.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
        # Out-of-place division: an in-place ``/=`` would shrink a tensor
        # ``fix_sigma`` owned by the caller on every call.
        bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd(self, X, Y):
        """Squared Euclidean distance between the row means of ``X`` and ``Y``."""
        delta = X.mean(dim=0) - Y.mean(dim=0)
        # ``delta`` is 1-D, so no transpose is needed for the dot product.
        return delta.dot(delta)

    def forward(self, source, target):
        """Compute the MMD loss for the configured kernel.

        Args:
            source: Source matrix, one sample per row.
            target: Target matrix, one sample per row.

        Returns:
            A zero-dimensional tensor holding the loss.

        Raises:
            ValueError: If ``self.kernel_type`` is neither ``"linear"`` nor
                ``"rbf"`` (previously ``None`` was returned silently).
        """
        if self.kernel_type == "linear":
            return self.linear_mmd(source, target)
        if self.kernel_type == "rbf":
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma
            )
            # NOTE(review): the reduction runs under ``no_grad``, so the
            # returned loss carries no gradient - confirm this is intended.
            with torch.no_grad():
                XX = torch.mean(kernels[:batch_size, :batch_size])
                YY = torch.mean(kernels[batch_size:, batch_size:])
                XY = torch.mean(kernels[:batch_size, batch_size:])
                YX = torch.mean(kernels[batch_size:, :batch_size])
                loss = torch.mean(XX + YY - XY - YX)
            return loss
        raise ValueError(f"Unsupported kernel_type: {self.kernel_type!r}")
|
||||
|
||||
|
||||
class Mine_estimator(nn.Module):
    """Loss built from the scores of a ``Mine`` network on paired and shuffled inputs."""

    def __init__(self, input_dim=2048, hidden_dim=512):
        """Create the underlying ``Mine`` network.

        Args:
            input_dim: Feature dimension of both inputs.
            hidden_dim: Hidden width of the network.
        """
        super().__init__()
        self.mine_model = Mine(input_dim, hidden_dim)

    def forward(self, X, Y):
        """Return ``-(mean(T(X, Y)) - log(mean(exp(T(X, Y_shuffled)))))``.

        ``T`` is ``self.mine_model`` and ``Y_shuffled`` is ``Y`` with its rows
        randomly permuted.

        Args:
            X: First input batch.
            Y: Second input batch.

        Returns:
            A zero-dimensional tensor holding the loss.
        """
        shuffled_y = Y[torch.randperm(len(Y))]
        joint_scores = self.mine_model(X, Y)
        marginal_scores = self.mine_model(X, shuffled_y)
        lower_bound = torch.mean(joint_scores) - torch.log(torch.mean(torch.exp(marginal_scores)))
        return -lower_bound
|
||||
|
||||
|
||||
class Mine(nn.Module):
    """Network that scores a pair of inputs with one value per row."""

    def __init__(self, input_dim=2048, hidden_dim=512):
        """Create one input projection per argument and a scalar output layer.

        Args:
            input_dim: Feature dimension of both inputs.
            hidden_dim: Width of the shared hidden representation.
        """
        super().__init__()
        self.fc1_x = nn.Linear(input_dim, hidden_dim)
        self.fc1_y = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, y):
        """Sum both projections, apply leaky ReLU and map to one score per row."""
        hidden = F.leaky_relu(self.fc1_x(x) + self.fc1_y(y))
        return self.fc2(hidden)
|
||||
|
||||
|
||||
def pairwise_dist(X, Y):
    """Squared Euclidean distance between every row of ``X`` and every row of ``Y``.

    Args:
        X: Tensor of shape ``(n, d)``.
        Y: Tensor of shape ``(m, d)``.

    Returns:
        Tensor of shape ``(n, m)`` whose entry ``[i, j]`` is the squared
        distance between ``X[i]`` and ``Y[j]``.
    """
    n, d = X.shape
    m, y_dim = Y.shape
    assert d == y_dim
    rows = X.unsqueeze(1).expand(n, m, d)
    cols = Y.unsqueeze(0).expand(n, m, d)
    return torch.pow(rows - cols, 2).sum(2)
|
||||
|
||||
|
||||
def pairwise_dist_np(X, Y):
    """NumPy version of the pairwise squared Euclidean distance.

    Args:
        X: Array of shape ``(n, d)``.
        Y: Array of shape ``(m, d)``.

    Returns:
        Array of shape ``(n, m)`` whose entry ``[i, j]`` is the squared
        distance between ``X[i]`` and ``Y[j]``.
    """
    n, d = X.shape
    m, y_dim = Y.shape
    assert d == y_dim
    # Repeat both inputs to a common (n, m, d) grid before subtracting.
    x_grid = np.tile(np.expand_dims(X, 1), (1, m, 1))
    y_grid = np.tile(np.expand_dims(Y, 0), (n, 1, 1))
    return np.power(x_grid - y_grid, 2).sum(2)
|
||||
|
||||
|
||||
def pa(X, Y):
    """Pairwise squared Euclidean distance via ``|x|^2 + |y|^2 - 2 x.y``.

    Args:
        X: Array of shape ``(n, d)``.
        Y: Array of shape ``(m, d)``.

    Returns:
        Array of shape ``(n, m)`` whose entry ``[i, j]`` is the squared
        distance between ``X[i]`` and ``Y[j]``.
    """
    cross = np.dot(X, Y.T)
    # Column vector of squared row norms of X, row vector for Y.
    x_sq = np.transpose([np.sum(np.square(X), axis=1)])
    y_sq = np.sum(np.square(Y), axis=1)
    return x_sq + y_sq - 2 * cross
|
||||
|
||||
|
||||
def kl_div(source, target):
    """KL-divergence loss with ``source.log()`` as input and ``target`` as target.

    The longer of the two inputs is truncated to the length of the shorter
    one before ``nn.KLDivLoss(reduction="batchmean")`` is applied.

    Args:
        source: Tensor whose logarithm is used as the loss input.
        target: Tensor used as the loss target.

    Returns:
        A zero-dimensional tensor holding the loss.
    """
    n_common = min(len(source), len(target))
    source, target = source[:n_common], target[:n_common]
    return nn.KLDivLoss(reduction="batchmean")(source.log(), target)
|
||||
|
||||
|
||||
def js(source, target):
    """Average of ``kl_div(source, M)`` and ``kl_div(target, M)``.

    ``M`` is the element-wise mean of ``source`` and ``target``. The longer
    input is truncated to the length of the shorter one first.

    Args:
        source: First tensor.
        target: Second tensor.

    Returns:
        Half the sum of the two ``kl_div`` losses against the mixture ``M``.
    """
    n_common = min(len(source), len(target))
    source, target = source[:n_common], target[:n_common]
    mixture = 0.5 * (source + target)
    return 0.5 * (kl_div(source, mixture) + kl_div(target, mixture))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user