.github/workflows/test.yml
@@ -43,16 +43,17 @@ jobs:
|
||||
- name: Lint with Black
|
||||
run: |
|
||||
cd ..
|
||||
python -m black qlib -l 120
|
||||
python -m black qlib -l 120 --check --diff
|
||||
|
||||
- name: Unit tests with Pytest
|
||||
run: |
|
||||
cd tests
|
||||
pytest . --durations=0
|
||||
|
||||
- name: Test data downloads and examples
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
|
||||
# cd examples
|
||||
# estimator -c estimator/estimator_config.yaml
|
||||
# jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html
|
||||
|
||||
- name: Test workflow by config
|
||||
run: |
|
||||
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
|
||||
README.md
@@ -9,7 +9,7 @@
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images/logo/1.png" />
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
</p>
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [Quant Model Zoo](#quant-model-zoo)
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [Quant Dataset Zoo](#quant-dataset-zoo)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
@@ -39,19 +41,17 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images/framework.png" />
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loosely coupled modules, and each component can be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Data layer` | `DataServer` focuses on providing high-performance infrastructure for users to manage and retrieve raw data. `DataEnhancement` will preprocess the data and provide the best dataset to be fed into the models. |
|
||||
| `Interday Model` | `Interday model` focuses on producing prediction scores (aka. _alpha_). Models are trained by `Model Creator` and managed by `Model Manager`. Users could choose one or multiple models for prediction. Multiple models could be combined with `Ensemble` module. |
|
||||
| `Interday Strategy` | `Portfolio Generator` will take prediction scores as input and output the orders based on the current position to achieve the target portfolio. |
|
||||
| `Intraday Trading` | `Order Executor` is responsible for executing orders output by `Interday Strategy` and returning the executed results. |
|
||||
| `Analysis` | Users could get a detailed analysis report of forecasting signals and portfolios in this part. |
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface for controlling the training process of models, which enables algorithms to control the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results. |
|
||||
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
* The modules with dashed borders are highly user-customizable and extendible.
|
||||
@@ -128,50 +128,53 @@ Users could create the same dataset with it.
|
||||
-->
|
||||
|
||||
## Auto Quant Research Workflow
|
||||
Qlib provides a tool named `Estimator` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building the dataset, training models, backtesting and evaluation). You can run an automatic quant research workflow and produce a graphical report analysis with the following steps:
|
||||
|
||||
1. Quant Research Workflow: Run `Estimator` with [estimator_config.yaml](examples/estimator/estimator_config.yaml) as following. (*Please note that this may **not work** under MacOS with Python 3.8 due to the incompatibility of the `sacred` package we use with Python 3.8. We will fix this bug in the future.*)
|
||||
1. Quant Research Workflow: Run `qrun` with the LightGBM workflow config ([workflow_config_lightgbm.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml)) as follows.
|
||||
```bash
|
||||
cd examples  # Avoid running the program from a directory that contains `qlib`
|
||||
estimator -c estimator/estimator_config.yaml
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
```
|
||||
The result of `Estimator` is as follows. Please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
The result of `qrun` is as follows. Please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
|
||||
```bash

                                                  risk
excess_return_without_cost mean               0.000675
                           std                0.005456
                           annualized_return  0.170077
                           information_ratio  1.963824
                           max_drawdown      -0.063646
excess_return_with_cost    mean               0.000479
                           std                0.005453
                           annualized_return  0.120776
                           information_ratio  1.395116
                           max_drawdown      -0.071216
'The following are analysis results of the excess return without cost.'
                       risk
mean               0.000708
std                0.005626
annualized_return  0.178316
information_ratio  1.996555
max_drawdown      -0.081806
'The following are analysis results of the excess return with cost.'
                       risk
mean               0.000512
std                0.005626
annualized_return  0.128982
information_ratio  1.444287
max_drawdown      -0.091078

```
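
The metrics in the report above appear to follow the usual definitions for a daily return series. As a minimal sketch, assuming 252 trading days per year and a drawdown measured on summed (not compounded) returns, they can be reproduced roughly as follows. The function name `risk_summary` and the constant `TRADING_DAYS` are illustrative and are not part of Qlib's API.

```python
import math
import statistics

TRADING_DAYS = 252  # assumption: number of trading days used for annualization


def risk_summary(daily_returns):
    """Summarise a list of daily (excess) returns with the five report metrics."""
    mean = statistics.fmean(daily_returns)
    std = statistics.stdev(daily_returns)  # sample standard deviation

    # Max drawdown: the largest drop of the running sum from its running peak.
    cumulative = 0.0
    peak = float("-inf")
    max_drawdown = 0.0
    for r in daily_returns:
        cumulative += r
        peak = max(peak, cumulative)
        max_drawdown = min(max_drawdown, cumulative - peak)

    return {
        "mean": mean,
        "std": std,
        "annualized_return": mean * TRADING_DAYS,
        "information_ratio": mean / std * math.sqrt(TRADING_DAYS),
        "max_drawdown": max_drawdown,
    }


if __name__ == "__main__":
    for name, value in risk_summary([0.01, -0.02, 0.03, -0.01]).items():
        print(f"{name:18s} {value: .6f}")
```

In the reported figures, `annualized_return` is close to `mean` multiplied by 252 (for example, 0.000708 × 252 ≈ 0.1784), which is consistent with this scaling.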
|
||||
Here are detailed documents for [Estimator](https://qlib.readthedocs.io/en/latest/component/estimator.html).
|
||||
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
|
||||
|
||||
2. Graphical Reports Analysis: Run `examples/estimator/analyze_from_estimator.ipynb` with `jupyter notebook` to get graphical reports
|
||||
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
|
||||
- Forecasting signal (model prediction) analysis
|
||||
- Cumulative Return of groups
|
||||

|
||||

|
||||
- Return distribution
|
||||

|
||||

|
||||
- Information Coefficient (IC)
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
- Auto Correlation of forecasting signal (model prediction)
|
||||

|
||||

|
||||
|
||||
- Portfolio analysis
|
||||
- Backtest return
|
||||

|
||||

|
||||
<!--
|
||||
- Score IC
|
||||

|
||||
@@ -184,21 +187,54 @@ Qlib provides a tool named `Estimator` to run the whole workflow automatically (
|
||||
-->
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/train_backtest_analyze.ipynb) is a demo for customized Quant research workflow by code.
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# Quant Model Zoo
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on lightgbm](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
|
||||
- [GBDT based on XGBoost](qlib/contrib/model/xgboost.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorch](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorch](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
|
||||
<!-- - [TFT based on tensorflow](examples/benchmarks/TFT/tft.py) -->
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
## Run a single model
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
`Qlib` provides three different ways to run a single model; users can pick the one that fits their case best:
|
||||
- Users can use the tool `qrun` mentioned above to run a model's workflow based on a config file.
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the shell command to use: `python run_all_model.py --models=lightgbm`, where the `--models` argument can take any number of the models listed above (the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supports *Linux* now. Other OS will be supported in the future.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored. (**Note**: the script will erase your previous experiment records created by running itself.)
|
||||
|
||||
Here is an example of running all the models for 10 iterations:
|
||||
```bash
|
||||
python run_all_model.py 10
|
||||
```
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
|
||||
# Quant Dataset Zoo
|
||||
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`.
|
||||
- [Alpha360](./qlib/contrib/estimator/handler.py)
|
||||
- [Alpha158](./qlib/contrib/estimator/handler.py)
|
||||
|
||||
| Dataset | US Market | China Market |
|
||||
| -- | -- | -- |
|
||||
| [Alpha360](./qlib/contrib/data/handler.py) | √ | √ |
|
||||
| [Alpha158](./qlib/contrib/data/handler.py) | √ | √ |
|
||||
|
||||
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
|
||||
Your PR to build new Quant dataset is highly welcomed.
|
||||
|
||||
@@ -62,6 +62,7 @@ It sees the key of the redis lock has existed in your redis db now. You can use
|
||||
> select 1
|
||||
> flushdb
|
||||
|
||||
If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If so, try using ``flushall`` to clear all the keys.
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
BIN docs/_static/img/analysis/analysis_model_IC.png (40 KiB before, 33 KiB after)
BIN docs/_static/img/analysis/analysis_model_NDQ.png (24 KiB before, 23 KiB after)
BIN (file name not shown) (52 KiB before, 47 KiB after)
BIN (file name not shown) (66 KiB before, 63 KiB after)
BIN (file name not shown) (17 KiB before, 16 KiB after)
BIN (file name not shown) (18 KiB before, 16 KiB after)
BIN docs/_static/img/analysis/report.png (163 KiB before, 160 KiB after)
BIN (file name not shown) (53 KiB before, 46 KiB after)
BIN docs/_static/img/analysis/risk_analysis_bar.png (15 KiB before, 13 KiB after)
BIN (file name not shown) (56 KiB before, 54 KiB after)
BIN (file name not shown) (57 KiB before, 53 KiB after)
BIN docs/_static/img/analysis/risk_analysis_std.png (47 KiB before, 47 KiB after)
BIN docs/_static/img/analysis/score_ic.png (105 KiB before, 102 KiB after)
BIN docs/_static/img/framework.png (205 KiB before, 271 KiB after)
@@ -1,4 +1,5 @@
|
||||
.. _alpha:
|
||||
|
||||
===========================
|
||||
Building Formulaic Alphas
|
||||
===========================
|
||||
@@ -49,7 +50,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.contrib.estimator.handler import QLibDataHandler
|
||||
>> from qlib.data.dataset.handler import QLibDataHandler
|
||||
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
|
||||
>> fields = [MACD_EXP] # MACD
|
||||
>> names = ['MACD']
|
||||
|
||||
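
The `MACD_EXP` string above combines three exponential moving averages. As a rough, Qlib-independent illustration, the sketch below evaluates the same arithmetic on a plain list of closing prices. The helper names `ema` and `macd` are hypothetical, and the EMA here is the simple recursive form with smoothing factor 2 / (span + 1), seeded by the first value. Qlib's own `EMA` operator may treat the warm-up period differently, so early values are not expected to match exactly.

```python
def ema(values, span):
    """Recursive EMA with smoothing factor 2 / (span + 1), seeded by the first value."""
    alpha = 2.0 / (span + 1)
    out = []
    for v in values:
        out.append(v if not out else alpha * v + (1 - alpha) * out[-1])
    return out


def macd(close):
    """Evaluate the arithmetic of MACD_EXP element-wise on a list of close prices."""
    fast, slow = ema(close, 12), ema(close, 26)
    # (EMA($close, 12) - EMA($close, 26)) / $close
    dif = [(f - s) / c for f, s, c in zip(fast, slow, close)]
    # EMA(dif, 9) / $close
    dea = [d / c for d, c in zip(ema(dif, 9), close)]
    return [d - e for d, e in zip(dif, dea)]


if __name__ == "__main__":
    prices = [100.0 + i for i in range(60)]  # made-up, steadily rising prices
    print(macd(prices)[-1])
```

For a flat price series both moving averages equal the price, so the expression evaluates to (approximately) zero; for a steadily rising series the fast average sits above the slow one and the value is positive.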
@@ -1,4 +1,5 @@
|
||||
.. _server:
|
||||
|
||||
=================================
|
||||
``Online`` & ``Offline`` mode
|
||||
=================================
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
.. _backtest:
|
||||
|
||||
============================================
|
||||
Intraday Trading: Model&Strategy Testing
|
||||
============================================
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
.. _data:
|
||||
|
||||
================================
|
||||
Data Layer: Data Framework&Usage
|
||||
Data Layer: Data Framework & Usage
|
||||
================================
|
||||
|
||||
Introduction
|
||||
@@ -14,7 +15,9 @@ The introduction of ``Data Layer`` includes the following parts.
|
||||
|
||||
- Data Preparation
|
||||
- Data API
|
||||
- Data Loader
|
||||
- Data Handler
|
||||
- Dataset
|
||||
- Cache
|
||||
- Data and Cache File Structure
|
||||
|
||||
@@ -30,13 +33,19 @@ Such data will be stored with filename suffix `.bin` (We'll call them `.bin` fil
|
||||
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the dataset as follows.
|
||||
``Qlib`` provides an off-the-shelf dataset in `.bin` format; users can use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
After running the above command, users can find china-stock data in Qlib format in the ``~/.qlib/qlib_data/cn_data`` directory.
|
||||
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
After running the above commands, users can find china-stock and us-stock data in Qlib format in the ``~/.qlib/qlib_data/cn_data`` directory and the ``~/.qlib/qlib_data/us_data`` directory respectively.
|
||||
|
||||
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
|
||||
|
||||
@@ -48,12 +57,45 @@ Converting CSV Format into Qlib Format
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert data in CSV format into `.bin` files (Qlib format).
|
||||
|
||||
|
||||
Users can download the china-stock data in CSV format as follows for reference to the CSV format.
|
||||
Users can download the demo china-stock data in CSV format as follows, as a reference for the CSV format.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfy** the following criteria:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
|
||||
- Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
|
||||
|
||||
- CSV file includes a column of the stock name. Users **must** specify the column name when dumping the data. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
symbol,close
|
||||
SH600000,120
|
||||
|
||||
- CSV file **must** include a column for the date, and when dumping the data, users must specify the date column name. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --date_field_name date
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
symbol,date,close,open,volume
|
||||
SH600000,2020-11-01,120,121,12300000
|
||||
SH600000,2020-11-02,123,120,12300000
|
||||
|
||||
|
||||
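
The column requirements above can be checked before dumping. The following is a minimal sketch, not part of ``dump_bin.py``: the helper ``missing_columns`` is hypothetical, and its two parameters simply mirror the values a user would pass to ``--date_field_name`` and ``--symbol_field_name``. It only inspects the header row.

```python
import csv
import io


def missing_columns(csv_text, date_field="date", symbol_field=None):
    """Return the required column names that are absent from the CSV header.

    symbol_field may be None when the file itself is named after the stock.
    """
    required = [date_field]
    if symbol_field:
        required.append(symbol_field)
    header = next(csv.reader(io.StringIO(csv_text)), [])
    present = {name.strip() for name in header}
    return [name for name in required if name not in present]


if __name__ == "__main__":
    sample = (
        "symbol,date,close,open,volume\n"
        "SH600000,2020-11-01,120,121,12300000\n"
        "SH600000,2020-11-02,123,120,12300000\n"
    )
    print(missing_columns(sample, date_field="date", symbol_field="symbol"))
```

An empty list means the header carries every column the dump step would be told to look for.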
Suppose that users have prepared their CSV format data in the directory ``~/.qlib/csv_data/my_data``; they can run the following command to start the conversion.
|
||||
|
||||
@@ -61,6 +103,12 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv
|
||||
|
||||
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
|
||||
|
||||
For other supported parameters when dumping the data into `.bin` files, users can view the help text by running the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python dump_bin.py dump_all --help
|
||||
|
||||
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
|
||||
|
||||
.. note::
|
||||
@@ -96,9 +144,8 @@ China-Stock Mode & US-Stock Mode
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)
|
||||
|
||||
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` does not provide a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Prepare data in CSV format
|
||||
- Convert data from CSV format to Qlib format, please refer to section `Converting CSV Format into Qlib Format <#converting-csv-format-into-qlib-format>`_.
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download US-stock data in Qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Suppose that users have prepared their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
@@ -140,72 +187,97 @@ Filter
|
||||
Expression dynamic instrument filter. Filter the instruments based on a certain expression. An expression rule indicating a certain feature field is required.
|
||||
|
||||
- `basic features filter`: rule_expression = '$close/$open>5'
|
||||
- `cross-sectional features filter` : rule_expression = '$rank($close)<10'
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
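
To illustrate what an expression-based dynamic filter does conceptually, here is a small sketch that is independent of ``Qlib``. The data layout, the function ``dates_passing_rule`` and the lambda rule are all assumptions made for illustration; ``Qlib`` parses the ``rule_expression`` string itself and does not use this code.

```python
def dates_passing_rule(bars, rule):
    """Return {instrument: [dates]} for the dates on which `rule` holds.

    bars: {instrument: [(date, open_price, close_price), ...]}
    rule: callable taking (open_price, close_price) and returning a bool.
    """
    selected = {}
    for instrument, rows in bars.items():
        dates = [date for date, open_price, close_price in rows
                 if rule(open_price, close_price)]
        if dates:  # instruments that never pass the rule are dropped entirely
            selected[instrument] = dates
    return selected


if __name__ == "__main__":
    # Made-up bars, only to show the shape of the input.
    bars = {
        "A": [("2020-01-01", 1.0, 6.0), ("2020-01-02", 2.0, 4.0)],
        "B": [("2020-01-01", 1.0, 1.0)],
    }
    # Python counterpart of the basic features rule '$close/$open>5'.
    print(dates_passing_rule(bars, lambda o, c: c / o > 5))
```

The point of the sketch is that membership is decided per date, so the same instrument can be inside the filtered universe on one day and outside it on the next.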
|
||||
|
||||
|
||||
Reference
|
||||
-------------
|
||||
|
||||
To know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_.
|
||||
|
||||
|
||||
Data Loader
|
||||
=================
|
||||
|
||||
``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It will be loaded and used in the ``Data Handler`` module.
|
||||
|
||||
QlibDataLoader
|
||||
---------------
|
||||
|
||||
The ``QlibDataLoader`` class in ``Qlib`` is an interface that allows users to load raw data from the data source.
|
||||
|
||||
Interface
|
||||
------------
|
||||
|
||||
Here are some interfaces of the ``QlibDataLoader`` class:
|
||||
|
||||
.. autoclass:: qlib.data.dataset.loader.QlibDataLoader
|
||||
:members: load, load_group_df
|
||||
|
||||
API
|
||||
-----------
|
||||
|
||||
To know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_.
|
||||
|
||||
|
||||
Data Handler
|
||||
=================
|
||||
|
||||
Users can use ``Data Handler`` in an automatic workflow by ``Estimator``, refer to `Estimator: Workflow Management <estimator.html>`_ for more details.
|
||||
The ``Data Handler`` module in ``Qlib`` is designed to handle the common data processing methods that are used by most of the models.
|
||||
|
||||
Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.contrib.estimator.handler.BaseDataHandler``, which provides some interfaces as follows.
|
||||
Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management <workflow.html>`_ for more details.
|
||||
|
||||
Base Class & Interface
|
||||
DataHandlerLP
|
||||
--------------
|
||||
|
||||
In addition to being used in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, with which users can easily preprocess data (standardization, removing NaN, etc.) and build datasets.
|
||||
|
||||
To achieve this, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that some learnable ``Processors`` learn the parameters of data processing. When new data comes in, these `trained` ``Processors`` can be applied to it directly, which makes processing real-time data efficient. More information about ``Processors`` is given in the next subsection.
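
The idea of a learnable processor can be sketched without ``Qlib`` at all. In the hedged example below, the class ``ZScoreSketch`` is hypothetical and is not ``Qlib``'s implementation: it learns a mean and a standard deviation once from a fitting sample, then reuses those stored parameters on any later data instead of re-estimating them.

```python
import statistics


class ZScoreSketch:
    """Illustrative learnable processor: parameters are learned once, then reused."""

    def fit(self, values):
        # Learn the normalization parameters from the fitting sample only.
        self.mean = statistics.fmean(values)
        self.std = statistics.pstdev(values) or 1.0  # fall back to 1.0 for constant data
        return self

    def __call__(self, values):
        # Apply the stored parameters; nothing is re-estimated from `values`.
        return [(v - self.mean) / self.std for v in values]


if __name__ == "__main__":
    processor = ZScoreSketch().fit([1.0, 2.0, 3.0])  # made-up fitting data
    print(processor([2.0, 5.0]))  # new data, normalized with the learned parameters
```

Separating the learning step from the application step is what keeps information from later data out of the learned parameters and makes applying the processor to incoming data cheap.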
|
||||
|
||||
|
||||
Interface
|
||||
----------------------
|
||||
|
||||
Qlib provides a base class `qlib.contrib.estimator.BaseDataHandler <../reference/api.html#qlib.contrib.estimator.handler.BaseDataHandler>`_, which provides the following interfaces:
|
||||
Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
|
||||
- `setup_feature`
|
||||
Implement the interface to load the data features.
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
- `setup_label`
|
||||
Implement the interface to load the data labels and calculate the users' labels.
|
||||
If users want to load features and labels by config, they can inherit from ``qlib.data.dataset.handler.ConfigDataHandler``. ``Qlib`` also provides some preprocessing methods in this subclass.
|
||||
|
||||
- `setup_processed_data`
|
||||
Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on.
|
||||
|
||||
Qlib also provides two functions to help users init the data handler, users can override them for users' needs.
|
||||
|
||||
- `_init_kwargs`
|
||||
Users can init the kwargs of the data handler in this function, some kwargs may be used when init the raw df.
|
||||
Kwargs are the other attributes in data.args, like dropna_label, dropna_feature
|
||||
|
||||
- `_init_raw_df`
|
||||
Users can init the raw df, feature names, and label names of data handler in this function.
|
||||
If the index of feature df and label df are not the same, users need to override this method to merge them (e.g. inner, left, right merge).
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
|
||||
|
||||
Usage
|
||||
--------------
|
||||
Processor
|
||||
----------
|
||||
|
||||
``Data Handler`` can be used as a single module, which provides the following methods:
|
||||
The ``Processor`` module in ``Qlib`` is designed to be learnable and it is responsible for handling data processing such as `normalization` and `drop none/nan features/labels`.
|
||||
|
||||
- `get_split_data`
|
||||
- According to the start and end dates, return features and labels of the pandas DataFrame type used for the 'Model'
|
||||
|
||||
- `get_rolling_data`
|
||||
- According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling.
|
||||
``Qlib`` provides the following ``Processors``:
|
||||
|
||||
- ``DropnaProcessor``: `processor` that drops N/A features.
|
||||
- ``DropnaLabel``: `processor` that drops N/A labels.
|
||||
- ``TanhProcess``: `processor` that uses `tanh` to process noise data.
|
||||
- ``ProcessInf``: `processor` that handles infinity values, which are replaced by the mean of the column.
|
||||
- ``Fillna``: `processor` that handles N/A values by filling them with 0 or another given number.
|
||||
- ``MinMaxNorm``: `processor` that applies min-max normalization.
|
||||
- ``ZscoreNorm``: `processor` that applies z-score normalization.
|
||||
- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
|
||||
- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
|
||||
- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
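
Unlike the normalizers that learn parameters over a time window, a cross-sectional normalizer works within each date separately. The sketch below shows that idea in plain Python under stated assumptions: the function ``cs_zscore`` and the row layout are hypothetical, the population standard deviation is used, and ``Qlib``'s ``CSZScoreNorm`` may differ in such details.

```python
import statistics
from collections import defaultdict


def cs_zscore(rows):
    """Normalize values within each date.

    rows: list of (date, instrument, value) tuples.
    Returns the same rows with value replaced by its per-date z-score.
    """
    by_date = defaultdict(list)
    for date, _, value in rows:
        by_date[date].append(value)

    # One (mean, std) pair per date; a zero std falls back to 1.0.
    stats = {
        date: (statistics.fmean(values), statistics.pstdev(values) or 1.0)
        for date, values in by_date.items()
    }
    return [
        (date, instrument, (value - stats[date][0]) / stats[date][1])
        for date, instrument, value in rows
    ]


if __name__ == "__main__":
    rows = [("d1", "A", 1.0), ("d1", "B", 3.0), ("d2", "A", 10.0), ("d2", "B", 10.0)]
    for row in cs_zscore(rows):
        print(row)
```

Because each date is normalized on its own, no parameters need to be carried over from a fitting period.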
|
||||
|
||||
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
|
||||
|
||||
To know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_.
|
||||
|
||||
Example
|
||||
--------------
|
||||
|
||||
``Data Handler`` can be run with ``estimator`` by modifying the configuration file, and can also be used as a single module.
|
||||
``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.
|
||||
|
||||
To learn more about how to run ``Data Handler`` with ``Estimator``, please refer to `Estimator: Workflow Management <estimator.html>`_.
|
||||
To learn more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_.
|
||||
|
||||
Qlib provides implemented data handler `Alpha158`. The following example shows how to run `Alpha158` as a single module.
|
||||
|
||||
@@ -214,44 +286,54 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.estimator.handler import Alpha158
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
import qlib
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"dropna_label": True,
|
||||
"start_date": "2007-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"market": "csi300",
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi300",
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2007-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
}
|
||||
if __name__ == "__main__":
|
||||
qlib.init()
|
||||
h = Alpha158(**data_handler_config)
|
||||
|
||||
exampleDataHandler = Alpha158(**DATA_HANDLER_CONFIG)
|
||||
# get all the columns of the data
|
||||
print(h.get_cols())
|
||||
|
||||
# example of 'get_split_data'
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG)
|
||||
# fetch all the labels
|
||||
print(h.fetch(col_set="label"))
|
||||
|
||||
# example of 'get_rolling_data'
|
||||
|
||||
for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG):
|
||||
print(x_train, y_train, x_validate, y_validate, x_test, y_test)
|
||||
|
||||
|
||||
.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the `fit`, `predict`, and `score` methods of the ``Interday Model``. Please refer to `Model <model.html#base-class-interface>`_.
|
||||
|
||||
Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.contrib.estimator.handler>`_.
|
||||
To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_.
|
||||
|
||||
|
||||
Dataset
|
||||
=================
|
||||
|
||||
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
|
||||
|
||||
The motivation of this module is to maximize the flexibility of different models to handle data in the way that suits them. This module gives each model the right to process its data in a unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` values, while neural networks such as ``MLP`` will break down on such data.
|
||||
|
||||
The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most important interface of the class:
|
||||
|
||||
.. autoclass:: qlib.data.dataset.__init__.DatasetH
|
||||
:members:
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
|
||||
|
||||
|
||||
|
||||
Cache
|
||||
==========
|
||||
|
||||
@@ -1,706 +0,0 @@
|
||||
.. _estimator:
|
||||
=================================
|
||||
Estimator: Workflow Management
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
The components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users can build their own Quant research workflow with these components, as in this `Example <https://github.com/microsoft/qlib/blob/main/examples/train_and_backtest.py>`_.
|
||||
|
||||
|
||||
Besides, ``Qlib`` provides more user-friendly interfaces named ``Estimator`` to automatically run the whole workflow defined by configuration. A concrete execution of the whole workflow is called an `experiment`.
|
||||
With ``Estimator``, users can easily run an `experiment`, which includes the following steps:
|
||||
|
||||
- Data
|
||||
- Loading
|
||||
- Processing
|
||||
- Slicing
|
||||
- Model
|
||||
- Training and inference(static or rolling)
|
||||
- Saving & loading
|
||||
- Evaluation(Back-testing)
|
||||
|
||||
For each `experiment`, ``Qlib`` will capture the model training details, performance evaluation results and basic information (e.g. names, ids). The captured data will be stored in backend-storage (disk or database).
|
||||
|
||||
Complete Example
|
||||
===================
|
||||
|
||||
Before getting into details, here is a complete example of ``Estimator``, which defines the workflow in typical Quant research.
|
||||
Below is a typical config file of ``Estimator``.
|
||||
|
||||
.. code-block:: YAML

    experiment:
        name: estimator_example
        observer_type: file_storage
        mode: train
    model:
        class: LGBModel
        module_path: qlib.contrib.model.gbdt
        args:
            loss: mse
            colsample_bytree: 0.8879
            learning_rate: 0.0421
            subsample: 0.8789
            lambda_l1: 205.6999
            lambda_l2: 580.9768
            max_depth: 8
            num_leaves: 210
            num_threads: 20
    data:
        class: Alpha158
        args:
            dropna_label: True
        filter:
            market: csi500
    trainer:
        class: StaticTrainer
        args:
            rolling_period: 360
            train_start_date: 2007-01-01
            train_end_date: 2014-12-31
            validate_start_date: 2015-01-01
            validate_end_date: 2016-12-31
            test_start_date: 2017-01-01
            test_end_date: 2020-08-01
    strategy:
        class: TopkDropoutStrategy
        args:
            topk: 50
            n_drop: 5
    backtest:
        normal_backtest_args:
            verbose: False
            limit_threshold: 0.095
            account: 100000000
            benchmark: SH000905
            deal_price: close
            open_cost: 0.0005
            close_cost: 0.0015
            min_cost: 5
    qlib_data:
        # when testing, please modify the following parameters according to the specific environment
        provider_uri: "~/.qlib/qlib_data/cn_data"
        region: "cn"
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
estimator -c configuration.yaml
|
||||
|
||||
.. note:: `estimator` will be placed in your $PATH directory when installing ``Qlib``.
|
||||
|
||||
|
||||
|
||||
Configuration File
|
||||
===================
|
||||
|
||||
Let's get into details of ``Estimator`` in this section.
|
||||
|
||||
Before using ``estimator``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
|
||||
|
||||
Experiment Section
|
||||
--------------------
|
||||
|
||||
First, the configuration file needs to contain a section named `experiment` with the basic information. This section describes how `estimator` tracks and persists the current `experiment`. ``Qlib`` uses `sacred`, a lightweight open-source tool, to configure, organize, and log experiments and to manage their results. Part of the behavior of `sacred` is determined by the `experiment` section.
|
||||
|
||||
The following files will be saved by `sacred` after `estimator` finishes an `experiment`:
|
||||
|
||||
- `model.bin`, model binary file
|
||||
- `pred.pkl`, model prediction result file
|
||||
- `analysis.pkl`, backtest performance analysis file
|
||||
- `positions.pkl`, backtest position records file
|
||||
- `run`, the experiment information object, usually contains some meta information such as the experiment name, experiment date, etc.
|
||||
|
||||
Here is a typical configuration of the `experiment` section:
|
||||
|
||||
.. code-block:: YAML

    experiment:
        name: test_experiment
        observer_type: mongo
        mongo_url: mongodb://MONGO_URL
        db_name: public
        finetune: false
        exp_info_path: /home/test_user/exp_info.json
        mode: test
        loader:
            id: 677
|
||||
|
||||
|
||||
The meaning of each field is as follows:
|
||||
|
||||
- `name`
|
||||
The experiment name, str type. `sacred <https://github.com/IDSIA/sacred>`_ will use this experiment name as an identifier for some important internal processes. Users can find this field in the `run` object of `sacred`. The default value is `test_experiment`.
|
||||
|
||||
- `observer_type`
|
||||
Observer type, str type. There are two choices: `file_storage` and `mongo`. If `file_storage` is selected, all the above-mentioned managed contents will be stored in the `dir` directory, with one subfolder per experiment run, named by the run number. If `mongo` is selected, the contents will be stored in the database. The default is `file_storage`.
|
||||
|
||||
- For `file_storage` observer.
|
||||
- `dir`
|
||||
Directory URL, str type, directory for `file_storage` observer type, files captured and managed by sacred with `file_storage` observer will be saved to this directory, which is the same directory as `config.json` by default.
|
||||
|
||||
- For `mongo` observer.
|
||||
- `mongo_url`
|
||||
Database URL, str type, required if the observer type is `mongo`.
|
||||
|
||||
- `db_name`
|
||||
Database name, str type, required if the observer type is `mongo`.
|
||||
|
||||
- `finetune`
|
||||
The behavior of ``Estimator`` when training models depends on this flag.
|
||||
If you just want to train models from scratch each time instead of based on existing models, please leave `finetune=false`. Otherwise please read the
|
||||
details below.
|
||||
|
||||
The following table is the processing logic for different situations.
|
||||
|
||||
========== =========================================== ==================================== =========================================== ===========================================
.          Static                                                                           Rolling
.          finetune:true                               finetune:false                       finetune:true                               finetune:false
========== =========================================== ==================================== =========================================== ===========================================
Train      - Need to provide model (Static or Rolling) - No need to provide model           - Need to provide model (Static or Rolling) - Need to provide model (Static or Rolling)
           - The args in model section will be         - The args in model section will be  - The args in model section will be         - The args in model section will be
             used for finetuning                         used for training                    used for finetuning                         used for finetuning
           - Update based on the provided model        - Train model from scratch           - Update based on the provided model        - Based on the provided model update
             and parameters                                                                   and parameters                            - Train model from scratch
                                                                                            - **Each rolling time slice is based on**   - **Train each rolling time slice**
                                                                                              **a model updated from the previous**       **separately**
                                                                                              **time**
Test       - Model must exist, otherwise an exception will be raised.
           - For `StaticTrainer`, users need to train a model and record 'exp_info' for 'Test'.
           - For `RollingTrainer`, users need to train a set of models until the latest time, and record 'exp_info' for 'Test'.
========== ========================================================================================================================================================================
|
||||
|
||||
.. note::
|
||||
|
||||
1. finetune parameters: share model.args parameters.
|
||||
|
||||
2. provide model: from `loader.model_index`, load the index of the model(starting from 0).
|
||||
|
||||
3. If `loader.model_index` is None:
|
||||
- In 'Static Finetune=True', if provide 'Rolling', use the last model to update.
|
||||
|
||||
- For `RollingTrainer` with Finetune=True.
|
||||
|
||||
- If `StaticTrainer` is used in loader, the model will be used for initialization for finetuning.
|
||||
|
||||
- If `RollingTrainer` is used in loader, the existing models will be used without any modification and the new models will be initialized with the model in the last period and finetune one by one.
|
||||
|
||||
|
||||
- `exp_info_path`
|
||||
save path of experiment info, str type, save the experiment info and model `prediction score` after the experiment is finished. Optional parameter, the default value is `<config_file_dir>/ex_name/exp_info.json`.
|
||||
|
||||
- `mode`
|
||||
`train` or `test`, str type.
|
||||
- `test mode` is designed for inference. Under `test mode`, it will load the model according to the parameters of `loader` and skip model training.
|
||||
- `train mode` is the default value. It trains new models by default.
|
||||
Please note that when it fails to load a model, it will fall back to `fit` the model.
|
||||
|
||||
.. note::
|
||||
|
||||
If users choose `test mode`, they need to make sure:
|
||||
- The `test_start_date` of the loader must be less than or equal to the current `test_start_date`.
|
||||
- If other parameters of the `loader` model args are different, a warning will appear.
|
||||
|
||||
|
||||
- `loader`
|
||||
If you just want to train models from scratch each time instead of based on existing models, please ignore `loader` section. Otherwise please read the
|
||||
details below.
|
||||
|
||||
The `loader` section only works when the `mode` is `test` or `finetune` is `true`.
|
||||
|
||||
- `model_index`
|
||||
Model index, int type. The index of the loaded model in loader_models (starting at 0) for the first `finetune`. The default value is None.
|
||||
|
||||
- `exp_info_path`
|
||||
Loader model experiment info path, str type. If the field exists, the following parameters will be parsed from `exp_info_path`, and the following parameters will not work. At least one of this field and `id` must exist.
|
||||
|
||||
- `id`
|
||||
The experiment id of the model that needs to be loaded, int type. If the `mode` is `test`, this value is required. At least one of this field and `exp_info_path` must exist.
|
||||
|
||||
- `name`
|
||||
The experiment name of the model that needs to be loaded, str type. The default value is the current experiment `name`.
|
||||
|
||||
- `observer_type`
|
||||
The experiment observer type of the model that needs to be loaded, str type. The default value is the current experiment `observer_type`.
|
||||
|
||||
.. note:: The observer type is a concept of the `sacred` module, which determines how the files, standard input, and output managed by `sacred` are stored.
|
||||
|
||||
|
||||
- `file_storage`
|
||||
If `observer_type` is `file_storage`, the config may be as follows.
|
||||
|
||||
.. code-block:: YAML

    experiment:
        name: test_experiment
        dir: <path to a directory> # default is dir of `config.yml`
        observer_type: file_storage
|
||||
- `mongo`
|
||||
If `observer_type` is `mongo`, the config may be as follows.
|
||||
|
||||
.. code-block:: YAML

    experiment:
        name: test_experiment
        observer_type: mongo
        mongo_url: mongodb://MONGO_URL
        db_name: public
|
||||
|
||||
Users need to indicate `mongo_url` and `db_name` for a mongo observer.
|
||||
|
||||
.. note::
|
||||
|
||||
If users choose the mongo observer, they need to make sure:
|
||||
- Have an environment with the mongodb installed and a mongo database dedicated to storing the results of the experiments.
|
||||
- The python environment (the version of python and package) to run the experiments and the one to fetch the results are consistent.
|
||||
|
||||
Model Section
|
||||
-----------------
|
||||
|
||||
Users can use a specified model by configuration with hyper-parameters.
|
||||
|
||||
Custom Models
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
``Qlib`` supports custom models. A custom model must be a subclass of `qlib.contrib.model.Model`. The config for a custom model may be as follows.
|
||||
|
||||
.. code-block:: YAML

    model:
        class: SomeModel
        module_path: /tmp/my_experment/custom_model.py
        args:
            loss: binary
|
||||
|
||||
|
||||
The class `SomeModel` should be in the module `custom_model`, and ``Qlib`` could parse the `module_path` to load the class.
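As an illustration of how a `class` / `module_path` pair could be resolved into a Python class, here is a minimal sketch. The helper `load_class` is hypothetical and is not ``Qlib``'s actual loader; it only assumes that `module_path` is either a dotted module name or a path to a `.py` file.

```python
import importlib
import importlib.util
from pathlib import Path


def load_class(class_name, module_path):
    """Hypothetical helper: resolve `class` + `module_path` into a class object."""
    if module_path.endswith(".py"):
        # Treat the value as a file path and build a module from that file.
        path = Path(module_path)
        spec = importlib.util.spec_from_file_location(path.stem, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    else:
        # Treat the value as a dotted module name, e.g. "qlib.contrib.model.gbdt".
        module = importlib.import_module(module_path)
    return getattr(module, class_name)
```

With such a helper, the `args` mapping of the config would be passed to the constructor, e.g. `load_class("SomeModel", "/tmp/my_experment/custom_model.py")(**args)`.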
|
||||
|
||||
To know more about ``Interday Model``, please refer to `Interday Model: Training & Prediction <model.html>`_.
|
||||
|
||||
Data Section
|
||||
-----------------
|
||||
|
||||
``Data Handler`` can be used to load raw data, prepare features and label columns, preprocess data (standardization, remove NaN, etc.), split training, validation, and test sets. It is a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`.
|
||||
|
||||
Users can use the specified data handler by config as follows.
|
||||
|
||||
.. code-block:: YAML

    data:
        class: Alpha158
        args:
            start_date: 2005-01-01
            end_date: 2018-04-30
            dropna_label: True
        filter:
            market: csi500
            filter_pipeline:
              -
                class: NameDFilter
                module_path: qlib.filter
                args:
                    name_rule_re: S(?!Z3)
                    fstart_time: 2018-01-01
                    fend_time: 2018-12-11
              -
                class: ExpressionDFilter
                module_path: qlib.filter
                args:
                    rule_expression: $open/$factor<=45
                    fstart_time: 2018-01-01
                    fend_time: 2018-12-11
|
||||
|
||||
- `class`
|
||||
Data handler class, str type, which should be a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`, and implements 5 important interfaces for loading features, loading raw data, preprocessing raw data, slicing train, validation, and test data. The default value is `ALPHA360`. If users want to write a data handler to retrieve the data in ``Qlib``, `QlibDataHandler` is suggested.
|
||||
|
||||
- `module_path`
|
||||
The module path, str type. An absolute url is also supported. It indicates the path of the data handler `class` implementation. The default value is `qlib.contrib.estimator.handler`.
|
||||
|
||||
- `args`
|
||||
Parameters used for ``Data Handler`` initialization.
|
||||
|
||||
- `train_start_date`
|
||||
Training start time, str type, the default value is `2005-01-01`.
|
||||
|
||||
- `start_date`
|
||||
Data start date, str type.
|
||||
|
||||
- `end_date`
|
||||
Data end date, str type. The range from `start_date` to `end_date` determines which part of the data is loaded in `datahandler`; users can only use this data in the following parts.
|
||||
|
||||
- `dropna_feature` (Optional in args)
|
||||
Drop Nan feature, bool type, the default value is False.
|
||||
|
||||
- `dropna_label` (Optional in args)
|
||||
Drop Nan label, bool type, the default value is True. Some multi-label tasks will use this.
|
||||
|
||||
- `normalize_method` (Optional in args)
|
||||
Normalize data by a given method. str type. ``Qlib`` gives two normalizing methods, `MinMax` and `Std`.
|
||||
If users want to build their own method, please override `_process_normalize_feature`.
|
||||
|
||||
- `filter`
|
||||
Dynamically filtering the stocks based on the filter pipeline.
|
||||
|
||||
- `market`
|
||||
index name, str type, the default value is `csi500`.
|
||||
|
||||
- `filter_pipeline`
|
||||
Filter rule list, list type, the default value is []. Can be customized according to users' needs.
|
||||
|
||||
- `class`
|
||||
Filter class name, str type.
|
||||
|
||||
- `module_path`
|
||||
The module path, str type.
|
||||
|
||||
- `args`
|
||||
The filter class parameters. These parameters are set according to the `class`, and all of them are passed as kwargs to `class`.
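To illustrate what a rule such as `name_rule_re: S(?!Z3)` selects, here is a small sketch using only the standard `re` module. It assumes that a name filter keeps the instruments whose code matches the regular expression from the start of the string; `filter_by_name` is a hypothetical helper, not the ``Qlib`` filter class itself.

```python
import re


def filter_by_name(instruments, name_rule_re):
    """Keep the instrument codes that match the rule at the start of the code."""
    pattern = re.compile(name_rule_re)
    return [inst for inst in instruments if pattern.match(inst)]


codes = ["SH600000", "SZ000001", "SZ300273", "BJ430047"]
# "S(?!Z3)" matches codes starting with "S" that do not continue with "Z3".
selected = filter_by_name(codes, "S(?!Z3)")
print(selected)  # ['SH600000', 'SZ000001']
```

Under this assumption, the rule excludes codes beginning with `SZ3` while keeping other codes that start with `S`.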
|
||||
|
||||
Custom Data Handler
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``Qlib`` supports custom data handlers. A custom data handler must be a subclass of ``qlib.contrib.estimator.handler.BaseDataHandler``. The config for a custom data handler may be as follows.
|
||||
|
||||
.. code-block:: YAML

    data:
        class: SomeDataHandler
        module_path: /tmp/my_experment/custom_data_handler.py
        args:
            start_date: 2005-01-01
            end_date: 2018-04-30
|
||||
|
||||
The class `SomeDataHandler` should be in the module `custom_data_handler`, and ``Qlib`` could parse the `module_path` to load the class.
|
||||
|
||||
If users want to load features and labels by config, they can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also has provided some preprocess methods in this subclass.
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended, from which users can inherit the custom class. `QLibDataHandler` is also a subclass of `ConfigDataHandler`.
|
||||
|
||||
To know more about ``Data Handler``, please refer to `Data Framework&Usage <data.html>`_.
|
||||
|
||||
Trainer Section
|
||||
-----------------
|
||||
|
||||
Users can specify the ``Trainer`` in the config file. A trainer is a subclass of ``qlib.contrib.estimator.trainer.BaseTrainer`` and implements three important interfaces for training the model, restoring the model, and getting model predictions, as follows.
|
||||
|
||||
- `train`
|
||||
Implement this interface to train the model.
|
||||
|
||||
- `load`
|
||||
Implement this interface to recover the model from disk.
|
||||
|
||||
- `get_pred`
|
||||
Implement this interface to get model prediction results.
|
||||
|
||||
``Qlib`` provides two implemented trainers:
|
||||
|
||||
- `StaticTrainer`
|
||||
The static trainer trains the model using the training, validation, and test data produced by the static slicing of the data processor.
|
||||
|
||||
- `RollingTrainer`
|
||||
The rolling trainer will use the rolling iterator of the data processor to split data for rolling training.
|
||||
|
||||
|
||||
Users can specify `trainer` with the configuration file:
|
||||
|
||||
.. code-block:: YAML

    trainer:
        class: StaticTrainer # or RollingTrainer
        args:
            rolling_period: 360
            train_start_date: 2005-01-01
            train_end_date: 2014-12-31
            validate_start_date: 2015-01-01
            validate_end_date: 2016-06-30
            test_start_date: 2016-07-01
            test_end_date: 2017-07-31
|
||||
|
||||
- `class`
|
||||
Trainer class, which should be a subclass of `qlib.contrib.estimator.trainer.BaseTrainer`, and needs to implement three important interfaces, the default value is `StaticTrainer`.
|
||||
|
||||
- `module_path`
|
||||
The module path, str type. An absolute url is also supported. It indicates the path of the trainer class implementation.
|
||||
|
||||
- `args`
|
||||
Parameters used for ``Trainer`` initialization.
|
||||
|
||||
- `rolling_period`
|
||||
The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. Only used in `RollingTrainer`.
|
||||
|
||||
- `train_start_date`
|
||||
Training start time, str type.
|
||||
|
||||
- `train_end_date`
|
||||
Training end time, str type.
|
||||
|
||||
- `validate_start_date`
|
||||
Validation start time, str type.
|
||||
|
||||
- `validate_end_date`
|
||||
Validation end time, str type.
|
||||
|
||||
- `test_start_date`
|
||||
Test start time, str type.
|
||||
|
||||
- `test_end_date`
|
||||
Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.
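The `test_end_date` rule described above can be sketched as follows. This is only an illustration of the stated behavior; `resolve_test_end_date` is a hypothetical helper and the dates are made up for the example.

```python
from datetime import date


def resolve_test_end_date(test_end_date, last_data_date):
    """Return the effective test end date.

    `-1`, or any date after the last available date, falls back to the last date.
    """
    if test_end_date == -1 or test_end_date == "-1":
        return last_data_date
    # Accept both ISO strings and date objects (YAML may parse dates natively).
    requested = date.fromisoformat(str(test_end_date))
    return min(requested, last_data_date)


last_date = date(2020, 8, 1)
print(resolve_test_end_date(-1, last_date))            # 2020-08-01
print(resolve_test_end_date("2017-07-31", last_date))  # 2017-07-31
print(resolve_test_end_date("2021-01-01", last_date))  # 2020-08-01
```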
|
||||
|
||||
Custom Trainer
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``Qlib`` supports custom trainers. A custom trainer must be a subclass of `qlib.contrib.estimator.trainer.BaseTrainer`. The config for a custom trainer may be as follows:
|
||||
|
||||
.. code-block:: YAML

    trainer:
        class: SomeTrainer
        module_path: /tmp/my_experment/custom_trainer.py
        args:
            train_start_date: 2005-01-01
            train_end_date: 2014-12-31
            validate_start_date: 2015-01-01
            validate_end_date: 2016-06-30
            test_start_date: 2016-07-01
            test_end_date: 2017-07-31
|
||||
|
||||
|
||||
The class `SomeTrainer` should be in the module `custom_trainer`, and ``Qlib`` could parse the `module_path` to load the class.
|
||||
|
||||
Strategy Section
|
||||
-----------------
|
||||
|
||||
Users can specify strategy through a config file, for example:
|
||||
|
||||
.. code-block:: YAML

    strategy :
        class: TopkDropoutStrategy
        args:
            topk: 50
            n_drop: 5
|
||||
|
||||
- `class`
|
||||
The strategy class, str type, should be a subclass of `qlib.contrib.strategy.strategy.BaseStrategy`. The default value is `TopkDropoutStrategy`.
|
||||
|
||||
- `module_path`
|
||||
The module location, str type. An absolute url or an absolute path is also supported. It indicates the location of the strategy class implementation.
|
||||
|
||||
- `args`
|
||||
Parameters used for ``Strategy`` initialization.
|
||||
|
||||
- `topk`
|
||||
The number of stocks in the portfolio
|
||||
|
||||
- `n_drop`
|
||||
Number of stocks to be replaced in each trading date
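To give an intuition for `topk` and `n_drop`, here is a deliberately simplified sketch of one rebalancing step. It is not the implementation of `TopkDropoutStrategy`; the function `topk_dropout_step` and its inputs (a score dictionary and a list of holdings) are assumptions made for illustration.

```python
def topk_dropout_step(scores, holdings, topk, n_drop):
    """One rebalancing step of a simplified top-k dropout rule.

    scores:   dict mapping stock -> prediction score for the trading date
    holdings: list of stocks currently held
    Returns the list of stocks to hold after the step.
    """
    ranked = sorted(scores, key=scores.get, reverse=True)
    if not holdings:
        # Nothing is held yet: start with the topk best-scored stocks.
        return ranked[:topk]
    # Mark the n_drop lowest-scored holdings as replaceable and keep the rest.
    held_ranked = sorted(holdings, key=lambda s: scores.get(s, float("-inf")), reverse=True)
    kept = held_ranked[: max(len(held_ranked) - n_drop, 0)]
    # Fill the portfolio up to topk with the best-scored stocks that are not kept;
    # a replaceable stock stays in the portfolio if nothing better is available.
    candidates = [s for s in ranked if s not in kept]
    return kept + candidates[: topk - len(kept)]


scores = {"A": 0.9, "B": 0.8, "C": 0.1, "D": 0.7, "E": 0.2}
print(topk_dropout_step(scores, ["A", "B", "C"], topk=3, n_drop=1))  # ['A', 'B', 'D']
```

In this sketch at most `n_drop` names change per trading date, which is the property that limits turnover.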
|
||||
|
||||
Custom Strategy
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
``Qlib`` supports custom strategies. A custom strategy must be a subclass of ``qlib.contrib.strategy.strategy.BaseStrategy``. The config for a custom strategy may be as follows:
|
||||
|
||||
|
||||
.. code-block:: YAML

    strategy :
        class: SomeStrategy
        module_path: /tmp/my_experment/custom_strategy.py
|
||||
|
||||
The class `SomeStrategy` should be in the module `custom_strategy`, and ``Qlib`` could parse the `module_path` to load the class.
|
||||
|
||||
To know more about ``Strategy``, please refer to `Strategy <strategy.html>`_.
|
||||
|
||||
Backtest Section
|
||||
-----------------
|
||||
|
||||
Users can specify `backtest` through a config file, for example:
|
||||
|
||||
.. code-block:: YAML

    backtest :
        normal_backtest_args:
            topk: 50
            benchmark: SH000905
            account: 500000
            deal_price: close
            min_cost: 5
            subscribe_fields:
              - $close
              - $change
              - $factor
|
||||
|
||||
- `normal_backtest_args`
|
||||
Normal backtest parameters. All the parameters in this section will be passed to the ``qlib.contrib.evaluate.backtest`` function in the form of `**kwargs`.
|
||||
|
||||
- `benchmark`
|
||||
Stock index symbol, str, or list type, the default value is `None`.
|
||||
|
||||
.. note::
|
||||
|
||||
* If `benchmark` is None, it will use the average change of the day of all stocks in 'pred' as the 'bench'.
|
||||
|
||||
* If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
|
||||
* If `benchmark` is str, it will use the daily change as the 'bench'.
|
||||
|
||||
|
||||
- `account`
|
||||
Backtest initial cash, integer type. The `account` in `strategy` section is deprecated. It only works when `account` is not set in `backtest` section. It will be overridden by `account` in the `backtest` section. The default value is 1e9.
|
||||
|
||||
- `deal_price`
|
||||
Order transaction price field, str type, the default value is close.
|
||||
|
||||
- `min_cost`
|
||||
Min transaction cost, float type, the default value is 5.
|
||||
|
||||
- `subscribe_fields`
|
||||
Subscribe quote fields, array type, the default value is [`deal_price`, $close, $change, $factor].
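The three `benchmark` cases from the note above (None, list, or str) can be sketched for a single trading day as follows. The function `bench_change` and the sample numbers are hypothetical; the sketch only mirrors the rules as stated and uses plain dictionaries rather than ``Qlib`` data structures.

```python
from statistics import mean


def bench_change(benchmark, daily_change, pred_stocks):
    """Return the benchmark change for one trading day.

    daily_change: dict mapping symbol -> daily change on that day
    pred_stocks:  symbols that appear in the prediction ('pred')
    """
    if isinstance(benchmark, str):
        # A single index symbol: use its own daily change.
        return daily_change[benchmark]
    # None -> average over all predicted stocks; list -> average over the given pool.
    pool = list(pred_stocks) if benchmark is None else list(benchmark)
    return mean(daily_change[symbol] for symbol in pool)


daily_change = {"SH000905": 0.01, "A": 0.02, "B": 0.04, "C": -0.03}
from_index = bench_change("SH000905", daily_change, ["A", "B"])
from_pred = bench_change(None, daily_change, ["A", "B"])
from_pool = bench_change(["A", "C"], daily_change, ["A", "B"])
```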
|
||||
|
||||
|
||||
Qlib Data Section
|
||||
--------------------
|
||||
|
||||
The `qlib_data` field describes the parameters of qlib initialization.
|
||||
|
||||
.. code-block:: YAML

    qlib_data:
        # when testing, please modify the following parameters according to the specific environment
        provider_uri: "~/.qlib/qlib_data/cn_data"
        region: "cn"
|
||||
|
||||
- `provider_uri`
|
||||
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
|
||||
- `region`
|
||||
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
|
||||
- `redis_host`
|
||||
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
|
||||
The lock and cache mechanism relies on redis.
|
||||
- `redis_port`
|
||||
Type: int, optional parameter(default: 6379), port of `redis`
|
||||
|
||||
.. note::
|
||||
|
||||
The value of `region` should be aligned with the data stored in `provider_uri`. Currently, ``scripts/get_data.py`` only provides China stock market data. If users want to use the US stock market data, they should prepare their own US-stock data in `provider_uri` and switch to US-stock mode.
|
||||
|
||||
.. note::
|
||||
|
||||
If ``Qlib`` fails to connect to redis via `redis_host` and `redis_port`, the cache mechanism will not be used! Please refer to `Cache <data.html#cache>`_ for details.
|
||||
|
||||
|
||||
Please refer to `Initialization <../start/initialization.html>`_.
|
||||
|
||||
Experiment Result
|
||||
===================
|
||||
|
||||
Form of Experimental Result
|
||||
----------------------------
|
||||
The result of the experiment is also the result of the ``Intraday Trading(Backtest)``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
|
||||
|
||||
Get Experiment Result
|
||||
----------------------------
|
||||
|
||||
Base Class & Interface
|
||||
~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Users can check the experiment results from file storage directly, or check the experiment results from the database, or get the experiment results through two interfaces of a base class `Fetcher` provided by ``Qlib``.
|
||||
|
||||
The `Fetcher` provides the following interfaces:
|
||||
- `get_experiments(self, exp_name=None):`
|
||||
The interface takes one parameter. The `exp_name` is the experiment name; by default, all experiments are returned. Users can get the returned dictionary with a list of ids and test end dates as follows.
|
||||
|
||||
.. code-block:: JSON

    {
        "ex_a": [
            {
                "id": 1,
                "test_end_date": "2017-01-01"
            }
        ],
        "ex_b": [
            ...
        ]
    }
|
||||
|
||||
|
||||
- `get_experiment(exp_name, exp_id, fields=None)`
|
||||
The interface takes three parameters. The first parameter is the experiment name, the second parameter is the experiment id, and the third parameter is a list of fields. The default value of `fields` is None, which means all fields.
|
||||
|
||||
|
||||
.. note::
|
||||
Currently supported fields:
|
||||
['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
|
||||
|
||||
Users can get the returned dictionary as follows.
|
||||
|
||||
.. code-block:: JSON

    {
        'analysis': analysis_df,
        'pred': pred_df,
        'positions': positions_dic,
        'report_normal': report_normal_df,
    }
|
||||
|
||||
Implemented `Fetcher` s & Examples
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``Qlib`` provides two implemented `Fetcher` s as follows.
|
||||
|
||||
`FileFetcher`
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
The `FileFetcher` is a subclass of `Fetcher`, which could fetch files from `file_storage` observer. The following is an example:
|
||||
.. code-block:: python

    >>> from qlib.contrib.estimator.fetcher import FileFetcher
    >>> f = FileFetcher(experiments_dir=r'./')
    >>> print(f.get_experiments())
    {
        'test_experiment': [
            {
                'id': '1',
                'config': ...
            },
            {
                'id': '2',
                'config': ...
            },
            {
                'id': '3',
                'config': ...
            }
        ]
    }
    >>> print(f.get_experiment('test_experiment', '1'))
                                                      risk
    excess_return_without_cost mean               0.000605
                               std                0.005481
                               annualized_return  0.152373
                               information_ratio  1.751319
                               max_drawdown      -0.059055
    excess_return_with_cost    mean               0.000410
                               std                0.005478
                               annualized_return  0.103265
                               information_ratio  1.187411
                               max_drawdown      -0.075024
|
||||
|
||||
|
||||
|
||||
`MongoFetcher`
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
The `MongoFetcher` is a subclass of `Fetcher`, which could fetch files from the `mongo` observer. Users should initialize the fetcher with `mongo_url`. The following is an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>>> from qlib.contrib.estimator.fetcher import MongoFetcher
|
||||
>>> f = MongoFetcher(mongo_url=..., db_name=...)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
.. _model:
|
||||
|
||||
============================================
|
||||
Interday Model: Model Training & Prediction
|
||||
============================================
|
||||
@@ -6,164 +7,138 @@ Interday Model: Model Training & Prediction
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``Estimator``, please refer to `Estimator: Workflow Management <estimator.html>`_.
|
||||
``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_.
|
||||
|
||||
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Interday Model`` can be used as an independent module also.
|
||||
|
||||
Base Class & Interface
|
||||
======================
|
||||
|
||||
``Qlib`` provides a base class `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ from which all models should inherit.
|
||||
``Qlib`` provides a base class `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ from which all models should inherit.
|
||||
|
||||
The base class provides the following interfaces:
|
||||
|
||||
- `__init__(**kwargs)`
|
||||
- Initialization.
|
||||
- If users use ``Estimator`` to start an `experiment`, the parameters of the `__init__` method should be consistent with the hyperparameters in the configuration file.
|
||||
|
||||
- `fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)`
|
||||
- `fit(self, dataset, **kwargs)`
|
||||
- Train model.
|
||||
- Parameter:
|
||||
- `x_train`, pd.DataFrame type, train feature
|
||||
The following example explains the value of `x_train`:
|
||||
- `dataset`, ``Qlib``'s ``DatasetH`` type. For more information about ``DatasetH``, users can refer to the related document: `Qlib Dataset <../component/data.html#dataset>`_.
|
||||
The `dataset` is passed into the `model`'s method because each model has its own data preprocessing procedures; we want to give each model maximum flexibility to handle the data in the way that suits it.
|
||||
The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
KMID KLEN KMID2 KUP KUP2
|
||||
instrument datetime
|
||||
SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275
|
||||
2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998
|
||||
2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666
|
||||
2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001
|
||||
2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648
|
||||
... ... ... ... ... ...
|
||||
SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052
|
||||
2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824
|
||||
2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000
|
||||
2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433
|
||||
2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
`x_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. Each column of `x_train` corresponds to a feature, and the column name is the feature name.
|
||||
|
||||
.. note::
|
||||
|
||||
The number and names of the columns are determined by the data handler, please refer to `Data Handler <data.html#data-handler>`_ and `Estimator Data Section <estimator.html#data-section>`_.
|
||||
|
||||
- `y_train`, pd.DataFrame type, train label
|
||||
The following example explains the value of `y_train`:
|
||||
# get features and labels
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
LABEL
|
||||
instrument datetime
|
||||
SH600004 2012-01-04 -0.798456
|
||||
2012-01-05 -1.366716
|
||||
2012-01-06 -0.491026
|
||||
2012-01-09 0.296900
|
||||
2012-01-10 0.501426
|
||||
... ...
|
||||
SZ300273 2014-12-25 -0.465540
|
||||
2014-12-26 0.233864
|
||||
2014-12-29 0.471368
|
||||
2014-12-30 0.411914
|
||||
2014-12-31 1.342723
|
||||
|
||||
`y_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. The `LABEL` column represents the value of train label.
|
||||
|
||||
.. note::
|
||||
|
||||
The number and names of the columns are determined by the ``Data Handler``, please refer to `Data Handler <data.html#data-handler>`_.
|
||||
|
||||
- `x_valid`, pd.DataFrame type, validation feature
|
||||
The format of `x_valid` is same as `x_train`
|
||||
|
||||
|
||||
- `y_valid`, pd.DataFrame type, validation label
|
||||
The format of `y_valid` is same as `y_train`
|
||||
|
||||
- `w_train`(Optional args, default is None), pd.DataFrame type, train weight
|
||||
`w_train` is a pandas DataFrame, whose shape and index is same as `x_train`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
|
||||
|
||||
- `w_train`(Optional args, default is None), pd.DataFrame type, validation weight
|
||||
`w_train` is a pandas DataFrame, whose shape and index is the same as `x_valid`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
|
||||
|
||||
- `predict(self, x_test, **kwargs)`
|
||||
- Predict test data 'x_test'
|
||||
- Parameter:
|
||||
- `x_test`, pd.DataFrame type, test features
|
||||
The form of `x_test` is same as `x_train` in 'fit' method.
|
||||
- Return:
|
||||
- `label`, np.ndarray type, test label
|
||||
The label of `x_test` that predicted by model.
|
||||
|
||||
- `score(self, x_test, y_test, w_test=None, **kwargs)`
|
||||
- Evaluate model with test feature/label
|
||||
- Parameter:
|
||||
- `x_test`, pd.DataFrame type, test feature
|
||||
The format of `x_test` is same as `x_train` in `fit` method.
|
||||
# get weights
|
||||
try:
|
||||
wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
|
||||
w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
|
||||
except KeyError as e:
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
- `x_test`, pd.DataFrame type, test label
|
||||
The format of `y_test` is same as `y_train` in `fit` method.
|
||||
- `predict(self, dataset, **kwargs)`
|
||||
- Predict test data.
|
||||
- Parameter:
|
||||
- `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
|
||||
- Returns:
|
||||
- Prediction results with type: `pandas.Series`.
|
||||
|
||||
- `w_test`, pd.DataFrame type, test weight
|
||||
The format of `w_test` is same as `w_train` in `fit` method.
|
||||
- Return: float type, evaluation score
|
||||
- `finetune(self, dataset, **kwargs)`
|
||||
- Finetune the model.
|
||||
- Parameter:
|
||||
- `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
|
||||
|
||||
For other interfaces such as `save`, `load`, `finetune`, please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
|
||||
|
||||
For other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
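The `fit(dataset)` / `predict(dataset)` contract described above can be illustrated with a self-contained sketch. It does not import ``Qlib``: `ToyDataset` and `MeanModel` are hypothetical stand-ins, and `ToyDataset.prepare` only mimics the idea of asking the dataset for a named segment, not the real ``DatasetH`` signature.

```python
from statistics import mean


class ToyDataset:
    """Hypothetical stand-in for a dataset; NOT Qlib's DatasetH."""

    def __init__(self, segments):
        # segments maps a segment name to a list of (features, label) rows
        self._segments = segments

    def prepare(self, segment):
        # Return the rows of one named segment, e.g. "train" or "test".
        return self._segments[segment]


class MeanModel:
    """Toy model that follows the fit(dataset) / predict(dataset) shape."""

    def __init__(self, **kwargs):
        # Hyperparameters would arrive here; this toy model has none.
        self.kwargs = kwargs
        self.label_mean = None

    def fit(self, dataset, **kwargs):
        # The model, not the caller, decides how to pull data out of the dataset.
        rows = dataset.prepare("train")
        self.label_mean = mean(label for _, label in rows)
        return self

    def predict(self, dataset, **kwargs):
        if self.label_mean is None:
            raise ValueError("model is not fitted yet")
        rows = dataset.prepare("test")
        # One prediction per test row: the mean training label.
        return [self.label_mean for _ in rows]


dataset = ToyDataset(
    {
        "train": [([0.1, 0.2], 1.0), ([0.3, 0.4], 3.0)],
        "test": [([0.5, 0.6], None), ([0.7, 0.8], None), ([0.9, 1.0], None)],
    }
)
model = MeanModel().fit(dataset)
predictions = model.predict(dataset)
```

Because the whole dataset object is handed to the model, a different model could request other segments or apply its own preprocessing without any change to the calling code.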
|
||||
|
||||
Example
|
||||
==================
|
||||
|
||||
``Qlib`` provides ``LightGBM`` and ``DNN`` models as the baseline, the following steps show how to run`` LightGBM`` as an independent module.
|
||||
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc. These models are treated as the baselines of ``Interday Model``. The following steps show how to run ``LightGBM`` as an independent module.
|
||||
|
||||
- Initialize ``Qlib`` with `qlib.init` first, please refer to `Initialization <../start/initialization.html>`_.
|
||||
- Run the following code to get the `prediction score` `pred_score`
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.estimator.handler import Alpha158
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"dropna_label": True,
|
||||
"start_date": "2007-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"market": MARKET,
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2007-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(
|
||||
**DATA_HANDLER_CONFIG
|
||||
).get_split_data(**TRAINER_CONFIG)
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
# train
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
MODEL_CONFIG = {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
}
|
||||
# use default model
|
||||
model = LGBModel(**MODEL_CONFIG)
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
_pred = model.predict(x_test)
|
||||
pred_score = pd.DataFrame(index=_pred.index)
|
||||
pred_score["score"] = _pred.iloc(axis=1)[0]
|
||||
|
||||
.. note:: `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
|
||||
.. note::
|
||||
|
||||
`Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
|
||||
`SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow <recorder.html#record-template>`_.
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
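The example above logs parameters with `R.log_params(**flatten_dict(task))`, which suggests that the nested `task` config is turned into flat key/value pairs before logging. A minimal sketch of that kind of flattening follows; `flatten` is a hypothetical helper written for illustration (the dot separator is an assumption), not the implementation of `qlib.utils.flatten_dict`.

```python
def flatten(nested, parent_key="", sep="."):
    """Flatten a nested dict into a single-level dict with joined keys."""
    flat = {}
    for key, value in nested.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict) and value:
            # Recurse into non-empty sub-dicts and merge their flattened items.
            flat.update(flatten(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


# A reduced version of the task config used in the example above.
task = {
    "model": {
        "class": "LGBModel",
        "kwargs": {"loss": "mse", "max_depth": 8},
    },
}
flat_task = flatten(task)
```

Each leaf value ends up under a single string key, which is the shape a keyword-argument based logging call can accept.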
|
||||
|
||||
@@ -175,4 +150,4 @@ Qlib supports custom models. If users are interested in customizing their own mo
|
||||
|
||||
API
|
||||
===================
|
||||
Please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
|
||||
Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
|
||||
|
||||
97
docs/component/recorder.rst
Normal file
@@ -0,0 +1,97 @@
|
||||
.. _recorder:
|
||||
|
||||
====================================
|
||||
Qlib Recorder: Experiment Management
|
||||
====================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
``Qlib`` contains an experiment management system named ``QlibRecorder``, which is designed to help users handle experiment and analysis results in an efficient way.
|
||||
|
||||
There are three components of the system:
|
||||
|
||||
- `ExperimentManager`
|
||||
a class that manages experiments.
|
||||
|
||||
- `Experiment`
|
||||
a class of experiment, and each instance of it is responsible for a single experiment.
|
||||
|
||||
- `Recorder`
|
||||
a class of recorder, and each instance of it is responsible for a single run.
|
||||
|
||||
Here is a general view of the structure of the system:
|
||||
|
||||
.. code-block::
|
||||
|
||||
ExperimentManager
|
||||
- Experiment 1
|
||||
- Recorder 1
|
||||
- Recorder 2
|
||||
- ...
|
||||
- Experiment 2
|
||||
- Recorder 1
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
|
||||
Currently, the components of this experiment management system are implemented using the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
|
||||
Qlib Recorder
|
||||
===================
|
||||
``QlibRecorder`` provides a high level API for users to use the experiment management system. The interfaces are wrapped in the variable ``R`` in ``Qlib``, and users can directly use ``R`` to interact with the system. The following command shows how to import ``R`` in Python:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.workflow import R
|
||||
|
||||
``QlibRecorder`` includes several common APIs for managing `experiments` and `recorders` within a workflow. For more available APIs, please refer to the following sections about `Experiment Manager`, `Experiment` and `Recorder`.
|
||||
|
||||
Here are the available interfaces of ``QlibRecorder``:
|
||||
|
||||
.. autoclass:: qlib.workflow.__init__.QlibRecorder
|
||||
:members:
|
||||
|
||||
Experiment Manager
|
||||
===================
|
||||
|
||||
The ``ExpManager`` module in ``Qlib`` is responsible for managing different experiments. Most of the APIs of ``ExpManager`` are similar to ``QlibRecorder``, and the most important API is the ``get_exp`` method. Users can refer to the documents above for detailed information about how to use the ``get_exp`` method.
|
||||
|
||||
.. autoclass:: qlib.workflow.expm.ExpManager
|
||||
:members: get_exp, list_experiments
|
||||
|
||||
For other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.
|
||||
|
||||
Experiment
|
||||
===================
|
||||
|
||||
The ``Experiment`` class is solely responsible for a single experiment, and it handles any operations that are related to an experiment. Basic methods such as `start` and `end` for an experiment are included. Methods related to `recorders` are also available, including `get_recorder` and `list_recorders`.
|
||||
|
||||
.. autoclass:: qlib.workflow.exp.Experiment
|
||||
:members: get_recorder, list_recorders
|
||||
|
||||
For other interfaces such as `search_records`, `delete_recorder`, please refer to `Experiment API <../reference/api.html#experiment>`_.
|
||||
|
||||
Recorder
|
||||
===================
|
||||
|
||||
The ``Recorder`` class is responsible for a single recorder. It handles detailed operations such as ``log_metrics`` and ``log_params`` of a single run. It is designed to help users easily track the results and artifacts generated during a run.
|
||||
|
||||
Here are some important APIs that are not included in the ``QlibRecorder``:
|
||||
|
||||
.. autoclass:: qlib.workflow.recorder.Recorder
|
||||
:members: list_artifacts, list_metrics, list_params, list_tags
|
||||
|
||||
For other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.
|
||||
|
||||
Record Template
|
||||
===================
|
||||
|
||||
The ``RecordTemp`` class enables generating experiment results such as IC and backtest in a certain format. We provide three different `Record Template` classes:
|
||||
|
||||
- ``SignalRecord``: This class generates the `prediction` results of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. For detailed information about `backtest` and the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
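For readers unfamiliar with the metrics listed above: IC is commonly defined as the correlation between prediction scores and realized labels on a given day, Rank IC as the same correlation computed on ranks, and ICIR as the mean of the daily IC divided by its standard deviation. The sketch below uses these common definitions with made-up numbers; it is not the code of ``SigAnaRecord``, and all function names are hypothetical.

```python
from math import sqrt
from statistics import mean, stdev


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var_x = sum((x - mx) ** 2 for x in xs)
    var_y = sum((y - my) ** 2 for y in ys)
    return cov / sqrt(var_x * var_y)


def ranks(values):
    """Rank positions (1 = smallest); ties are not handled in this sketch."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0] * len(values)
    for rank, idx in enumerate(order, start=1):
        result[idx] = rank
    return result


def ic(scores, labels):
    # Correlation between raw scores and raw labels.
    return pearson(scores, labels)


def rank_ic(scores, labels):
    # Correlation between the rank positions of scores and labels.
    return pearson(ranks(scores), ranks(labels))


def icir(daily_ic):
    # Mean of the daily IC series divided by its sample standard deviation.
    return mean(daily_ic) / stdev(daily_ic)


scores = [0.9, 0.1, 0.5, 0.3]       # made-up prediction scores for four stocks
labels = [0.04, -0.02, 0.03, 0.00]  # made-up realized returns
day_ic = ic(scores, labels)
day_rank_ic = rank_ic(scores, labels)
```

In this toy data the scores order the stocks exactly as the realized returns do, so the Rank IC is at its maximum while the plain IC is slightly lower.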
|
||||
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
@@ -1,12 +1,13 @@
|
||||
.. _report:
|
||||
|
||||
==========================================
|
||||
Aanalysis: Evaluation & Results Analysis
|
||||
Analysis: Evaluation & Results Analysis
|
||||
==========================================
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Aanalysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
|
||||
``Analysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
|
||||
|
||||
- analysis_position
|
||||
- report_graph
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
.. _strategy:
|
||||
|
||||
========================================
|
||||
Interday Strategy: Portfolio Management
|
||||
========================================
|
||||
|
||||
280
docs/component/workflow.rst
Normal file
@@ -0,0 +1,280 @@
|
||||
.. _workflow:
|
||||
|
||||
=================================
|
||||
Workflow: Workflow Management
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
The components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users could build their own Quant research workflow with these components like `Example <https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py>`_.
|
||||
|
||||
|
||||
Besides, ``Qlib`` provides a more user-friendly interface named ``qrun`` to automatically run the whole workflow defined by configuration. A concrete execution of the whole workflow is called an `experiment`.
|
||||
With ``qrun``, users can easily run an `experiment`, which includes the following steps:
|
||||
|
||||
- Data
|
||||
- Loading
|
||||
- Processing
|
||||
- Slicing
|
||||
- Model
|
||||
- Training and inference
|
||||
- Saving & loading
|
||||
- Evaluation
|
||||
- Forecast signal analysis
|
||||
- Backtest
|
||||
|
||||
For each `experiment`, ``Qlib`` has a complete system to track all the information as well as the artifacts generated during the training, inference and evaluation phases. For more information about how ``Qlib`` handles an `experiment`, please refer to the related document: `Recorder: Experiment Management <../component/recorder.html>`_.
|
||||
|
||||
Complete Example
|
||||
===================
|
||||
|
||||
Before getting into details, here is a complete example of ``qrun``, which defines the workflow in typical Quant research.
|
||||
Below is a typical config file of ``qrun``.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
qrun -c configuration.yaml
|
||||
|
||||
.. note::
|
||||
|
||||
`qrun` will be placed in your $PATH directory when installing ``Qlib``.
|
||||
|
||||
|
||||
Configuration File
|
||||
===================
|
||||
|
||||
Let's get into details of ``qrun`` in this section.
|
||||
|
||||
Before using ``qrun``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
|
||||
|
||||
Qlib Data Section
|
||||
--------------------
|
||||
|
||||
At first, the configuration file needs to contain several basic parameters about the data, which will be used for qlib initialization, data handling and backtest.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
|
||||
The meaning of each field is as follows:
|
||||
|
||||
- `provider_uri`
|
||||
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
|
||||
|
||||
- `region`
|
||||
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
|
||||
- If `region` == "cn", ``Qlib`` will be initialized in China-stock mode.
|
||||
|
||||
.. note::
|
||||
|
||||
The value of `region` should be aligned with the data stored in `provider_uri`.
|
||||
|
||||
- `market`
|
||||
Type: str. Index name, the default value is `csi500`.
|
||||
|
||||
- `benchmark`
|
||||
Type: str, list or pandas.Series. Stock index symbol, the default value is `SH000905`.
|
||||
|
||||
.. note::
|
||||
|
||||
* If `benchmark` is str, it will use the daily change as the 'bench'.
|
||||
|
||||
* If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'.
|
||||
|
||||
* If `benchmark` is pandas.Series, whose `index` is the trading date and whose value at T is the change from T-1 to T, it will be directly used as the 'bench'. An example is as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
|
||||
2017-01-04 0.011693
|
||||
2017-01-05 0.000721
|
||||
2017-01-06 -0.004322
|
||||
2017-01-09 0.006874
|
||||
2017-01-10 -0.003350
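The expression `$close/Ref($close, 1)-1` above is the close-to-close change from T-1 to T. A plain-Python sketch of the same arithmetic follows, using made-up prices; `daily_change` is a hypothetical helper and does not call ``Qlib``.

```python
def daily_change(closes):
    """Return close[t] / close[t-1] - 1 for every t after the first day."""
    return [curr / prev - 1 for prev, curr in zip(closes, closes[1:])]


closes = [100.0, 101.0, 100.0, 102.0]  # made-up closing prices
changes = daily_change(closes)
```

The result has one element fewer than the input, because the first day has no previous close to compare against.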
|
||||
.. note::
|
||||
|
||||
The symbol `&` in a `yaml` file marks an anchor on a field, which is useful when other fields include this parameter as part of their value (referenced with `*`). Taking the configuration file above as an example, users can change the value of `market` and `benchmark` in one place without traversing the entire configuration file.
|
||||
|
||||
Model Section
|
||||
--------------------
|
||||
|
||||
In the `task` field, the `model` section describes the parameters of the model to be used for training and inference. For more information about the base ``Model`` class, please refer to `Qlib Model <../component/model.html>`_.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
|
||||
The meaning of each field is as follows:
|
||||
|
||||
- `class`
|
||||
Type: str. The name for the model class.
|
||||
|
||||
- `module_path`
|
||||
Type: str. The path for the model in qlib.
|
||||
|
||||
- `kwargs`
|
||||
The keyword arguments for the model. Please refer to the specific model implementation for more information: `models <https://github.com/microsoft/qlib/blob/main/qlib/contrib/model>`_.
|
||||
|
||||
.. note::
|
||||
|
||||
``Qlib`` provides a utility named ``init_instance_by_config`` to initialize any class inside ``Qlib`` from a configuration that includes the fields: `class`, `module_path` and `kwargs`.
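The idea behind a `class` / `module_path` / `kwargs` config can be sketched with the standard library alone. The helper below, `build_from_config`, is hypothetical and is not the implementation of ``init_instance_by_config``; it only shows how such a config can be resolved to an object by importing the module, looking up the class, and calling it with the keyword arguments. The standard-library `fractions.Fraction` stands in for a model class.

```python
import importlib


def build_from_config(config):
    """Instantiate config["class"] from config["module_path"] with config["kwargs"]."""
    # Import the module named in the config, e.g. "fractions".
    module = importlib.import_module(config["module_path"])
    # Look up the class by name inside that module.
    cls = getattr(module, config["class"])
    # Call the class with the keyword arguments (none if "kwargs" is absent).
    return cls(**config.get("kwargs", {}))


config = {
    "class": "Fraction",
    "module_path": "fractions",
    "kwargs": {"numerator": 3, "denominator": 4},
}
instance = build_from_config(config)
```

Keeping the class name and module path as plain strings is what allows the same information to live in a YAML file and be swapped without touching code.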
|
||||
|
||||
Dataset Section
|
||||
--------------------
|
||||
|
||||
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well as those for the ``DataHandler`` module. For more information about the ``Dataset`` module, please refer to `Qlib Dataset <../component/data.html#dataset>`_.
|
||||
|
||||
The keyword arguments configuration of the ``DataHandler`` is as follows:
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
|
||||
Users can refer to the document of `DataHandler <../component/data.html#datahandler>`_ for more information about the meaning of each field in the configuration.
|
||||
|
||||
Here is the configuration for the ``Dataset`` module, which takes care of data preprocessing and slicing during the training and testing phases.
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
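The `segments` mapping assigns each phase a date range. As an illustration of what slicing by such ranges means, the sketch below assigns dates to segments using the ranges from the config above; `segment_of` is a hypothetical helper, and treating both endpoints as inclusive is an assumption rather than documented ``DatasetH`` behavior.

```python
from datetime import date

# Date ranges copied from the `segments` field of the config above.
SEGMENTS = {
    "train": (date(2008, 1, 1), date(2014, 12, 31)),
    "valid": (date(2015, 1, 1), date(2016, 12, 31)),
    "test": (date(2017, 1, 1), date(2020, 8, 1)),
}


def segment_of(day, segments=SEGMENTS):
    """Return the name of the segment containing `day`, or None."""
    for name, (start, end) in segments.items():
        # Assumption: both endpoints belong to the segment.
        if start <= day <= end:
            return name
    return None
```

For example, `segment_of(date(2015, 6, 1))` returns `"valid"`, while a date after 2020-08-01 falls in no segment.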
|
||||
|
||||
Record Section
|
||||
--------------------
|
||||
|
||||
The `record` field describes the parameters of the ``Record`` module in ``Qlib``. ``Record`` is responsible for generating certain analysis and evaluation results such as `prediction`, `Information Coefficient (IC)` and `backtest`.
|
||||
|
||||
The following script is the configuration of `backtest` and the `strategy` used in `backtest`:
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
For more information about the meaning of each field in configuration of `strategy` and `backtest`, users can look up the documents: `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
Here are the configuration details of different `Record Template` classes such as ``SignalRecord`` and ``PortAnaRecord``:
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
|
||||
For more information about the ``Record`` module in ``Qlib``, users can refer to the related document: `Record <../component/recorder.html#record-template>`_.
|
||||
@@ -124,7 +124,7 @@ html_theme_options = {
|
||||
"logo_only": True,
|
||||
"collapse_navigation": False,
|
||||
"display_version": False,
|
||||
"navigation_depth": 3,
|
||||
"navigation_depth": 4,
|
||||
}
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
|
||||
@@ -35,12 +35,13 @@ Document Structure
|
||||
:maxdepth: 3
|
||||
:caption: COMPONENTS:
|
||||
|
||||
Estimator: Workflow Management <component/estimator.rst>
|
||||
Workflow: Workflow Management <component/workflow.rst>
|
||||
Data Layer: Data Framework&Usage <component/data.rst>
|
||||
Interday Model: Model Training & Prediction <component/model.rst>
|
||||
Interday Strategy: Portfolio Management <component/strategy.rst>
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Aanalysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
@@ -48,6 +49,7 @@ Document Structure
|
||||
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
:caption: REFERENCE:
|
||||
|
||||
@@ -21,27 +21,27 @@ Framework
|
||||
|
||||
At the module level, Qlib is a platform that consists of above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
|
||||
====================== ==============================================================================
|
||||
Name Description
|
||||
====================== ==============================================================================
|
||||
`Data layer` `DataServer` focuses on providing high-performance infrastructure for users to
|
||||
manage and retrieve raw data. `DataEnhancement` will preprocess the data and
|
||||
provide the best dataset to be fed into the models.
|
||||
|
||||
`Interday Model` `Interday model` focuses on producing prediction scores (aka. `alpha`). Models
|
||||
are trained by `Model Creator` and managed by `Model Manager`. Users could
|
||||
choose one or multiple models for prediction. Multiple models could be combined
|
||||
with `Ensemble` module.
|
||||
|
||||
`Interday Strategy` `Portfolio Generator` will take prediction scores as input and output the
|
||||
orders based on the current position to achieve the target portfolio.
|
||||
======================== ==============================================================================
|
||||
Name Description
|
||||
======================== ==============================================================================
|
||||
`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
|
||||
`DataServer` provides high-performance infrastructure for users to manage
|
||||
and retrieve raw data. `Trainer` provides flexible interface to control
|
||||
the training process of models which enable algorithms controlling the
|
||||
training process.
|
||||
|
||||
`Intraday Trading` `Order Executor` is responsible for executing orders output by
|
||||
`Interday Strategy` and returning the executed results.
|
||||
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
|
||||
`Information Extractor` extracts data for models. `Forecast Model` focuses
|
||||
on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
|
||||
modules. With these signals `Portfolio Generator` will generate the target
|
||||
portfolio and produce orders to be executed by `Order Executor`.
|
||||
|
||||
`Analysis` Users could get a detailed analysis report of forecasting signals and portfolios
|
||||
in this part.
|
||||
====================== ==============================================================================
|
||||
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
|
||||
system. `Analyser` module will provide users detailed analysis reports of
|
||||
forecasting signals, portfolios and execution results
|
||||
======================== ==============================================================================
|
||||
|
||||
- The modules with hand-drawn style are under development and will be released in the future.
|
||||
- The modules with dashed borders are highly user-customizable and extendible.
|
||||
|
||||
@@ -49,18 +49,19 @@ To kown more about `prepare data`, please refer to `Data Preparation <../compone
|
||||
Auto Quant Research Workflow
|
||||
====================================
|
||||
|
||||
``Qlib`` provides a tool named ``Estimator`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
``Qlib`` provides a tool named ``qrun`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
|
||||
|
||||
- Quant Research Workflow:
|
||||
- Run ``Estimator`` with `estimator_config.yaml` as following.
|
||||
- Run ``qrun`` with the LightGBM config file `workflow_config_lightgbm.yaml` as follows.
|
||||
|
||||
.. code-block::
|
||||
|
||||
cd examples # Avoid running program under the directory contains `qlib`
|
||||
estimator -c estimator/estimator_config.yaml
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
|
||||
|
||||
- Estimator result
|
||||
The result of ``Estimator`` is as follows, which is also the result of ``Intraday Trading``. Please refer to `Intraday Trading <../component/backtest.html>`_. for more details about the result.
|
||||
- Workflow result
|
||||
The result of ``qrun`` is as follows, which is also the typical result of ``Forecast model(alpha)``. Please refer to `Intraday Trading <../component/backtest.html>`_ for more details about the result.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -77,17 +78,17 @@ Auto Quant Research Workflow
|
||||
max_drawdown -0.075024
|
||||
|
||||
|
||||
To know more about `Estimator`, please refer to `Estimator: Workflow Management <../component/estimator.html>`_.
|
||||
To know more about `workflow` and `qrun`, please refer to `Workflow: Workflow Management <../component/workflow.html>`_.
|
||||
|
||||
- Graphical Reports Analysis:
|
||||
- Run ``examples/estimator/analyze_from_estimator.ipynb`` with jupyter notebook
|
||||
Users can have portfolio analysis or prediction score (model prediction) analysis by run ``examples/estimator/analyze_from_estimator.ipynb``.
|
||||
- Run ``examples/workflow_by_code.ipynb`` with jupyter notebook
|
||||
Users can perform portfolio analysis or prediction score (model prediction) analysis by running ``examples/workflow_by_code.ipynb``.
|
||||
- Graphical Reports
|
||||
Users can get graphical reports about the analysis, please refer to `Aanalysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
|
||||
Users can get graphical reports about the analysis, please refer to `Analysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
|
||||
|
||||
|
||||
|
||||
Custom Model Integration
|
||||
===============================================
|
||||
|
||||
``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. If users are interested in the custom model, please refer to `Custom Model Integration <../start/integration.html>`_.
|
||||
``Qlib`` provides a batch of models (such as ``lightGBM`` and ``MLP`` models) as examples of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. If users are interested in the custom model, please refer to `Custom Model Integration <../start/integration.html>`_.
|
||||
|
||||
@@ -23,16 +23,13 @@ Filter
|
||||
.. automodule:: qlib.data.filter
|
||||
:members:
|
||||
|
||||
Feature
|
||||
--------------------
|
||||
|
||||
Class
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
--------------------
|
||||
.. automodule:: qlib.data.base
|
||||
:members:
|
||||
|
||||
Operator
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
--------------------
|
||||
.. automodule:: qlib.data.ops
|
||||
:members:
|
||||
|
||||
@@ -56,19 +53,36 @@ Cache
|
||||
.. autoclass:: qlib.data.cache.DiskDatasetCache
|
||||
:members:
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
|
||||
Dataset Class
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.__init__
|
||||
:members:
|
||||
|
||||
Data Loader
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.loader
|
||||
:members:
|
||||
|
||||
Data Handler
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.handler
|
||||
:members:
|
||||
|
||||
Processor
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.processor
|
||||
:members:
|
||||
|
||||
|
||||
Contrib
|
||||
====================
|
||||
|
||||
|
||||
Data Handler
|
||||
---------------
|
||||
.. automodule:: qlib.contrib.estimator.handler
|
||||
:members:
|
||||
|
||||
Model
|
||||
--------------------
|
||||
.. automodule:: qlib.contrib.model.base
|
||||
.. automodule:: qlib.model.base
|
||||
:members:
|
||||
|
||||
Strategy
|
||||
@@ -116,3 +130,26 @@ Report
|
||||
:members:
|
||||
|
||||
|
||||
Workflow
|
||||
====================
|
||||
|
||||
|
||||
Experiment Manager
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.expm.ExpManager
|
||||
:members:
|
||||
|
||||
Experiment
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.exp.Experiment
|
||||
:members:
|
||||
|
||||
Recorder
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.recorder.Recorder
|
||||
:members:
|
||||
|
||||
Record Template
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
@@ -1,4 +1,5 @@
|
||||
.. _getdata:
|
||||
|
||||
=============================
|
||||
Data Retrieval
|
||||
=============================
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
.. _initialization:
|
||||
|
||||
====================
|
||||
Qlib Initialization
|
||||
====================
|
||||
@@ -11,14 +12,16 @@ Initialization
|
||||
|
||||
Please follow the steps below to initialize ``Qlib``.
|
||||
|
||||
- Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>` for more information about customized dataset.
|
||||
Download and prepare the data: execute the following command to download stock data. Please note that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and might not be perfect. We recommend that users prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized datasets.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
|
||||
|
||||
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
|
||||
|
||||
|
||||
- Initialize Qlib before calling other APIs: run following code in python.
|
||||
Initialize Qlib before calling other APIs: run the following code in Python.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
@@ -58,3 +61,16 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
.. note::
|
||||
|
||||
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
|
||||
- `exp_manager`
|
||||
Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, if you want to set your tracking_uri to a <specific folder>, you can initialize qlib below
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager= {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "python_execution_path/mlruns",
|
||||
"default_exp_name": "Experiment",
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
.. _installation:
|
||||
|
||||
====================
|
||||
Installation
|
||||
====================
|
||||
|
||||
@@ -5,17 +5,17 @@ Custom Model Integration
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``.
|
||||
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, and ``LSTM``. These models are examples of ``Interday Model``. In addition to the default models that ``Qlib`` provides, users can integrate their own custom models into ``Qlib``.
|
||||
|
||||
Users can integrate their own custom models according to the following steps.
|
||||
|
||||
- Define a custom model class, which should be a subclass of the `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_.
|
||||
- Define a custom model class, which should be a subclass of the `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_.
|
||||
- Write a configuration file that describes the path and parameters of the custom model.
|
||||
- Test the custom model.
|
||||
|
||||
Custom Model Class
|
||||
===========================
|
||||
The Custom models need to inherit `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ and override the methods in it.
|
||||
The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ and override the methods in it.
|
||||
|
||||
- Override the `__init__` method
|
||||
- ``Qlib`` passes the initialized parameters to the \_\_init\_\_ method.
|
||||
@@ -32,79 +32,77 @@ The Custom models need to inherit `qlib.contrib.model.base.Model <../reference/a
|
||||
|
||||
- Override the `fit` method
|
||||
- ``Qlib`` calls the fit method to train the model
|
||||
- The parameters must include training feature `x_train`, training label `y_train`, test feature `x_valid`, test label `y_valid` at least.
|
||||
- The parameters could include some optional parameters with default values, such as train weight `w_train`, test weight `w_valid` and `num_boost_round = 1000`.
|
||||
- The parameters must include `dataset`, from which the training and validation data are prepared.
|
||||
- The parameters could include some optional parameters with default values, such as `num_boost_round = 1000` for `GBDT`.
|
||||
- Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.
|
||||
.. code-block:: Python
|
||||
|
||||
def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame,
|
||||
w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs):
|
||||
def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
|
||||
|
||||
# prepare dataset for lgb training and evaluation
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
# LightGBM needs a 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError('LightGBM doesn\'t support multi-label training')
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
w_train_weight = None if w_train is None else w_train.values
|
||||
w_valid_weight = None if w_valid is None else w_valid.values
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
|
||||
self._model = lgb.train(
|
||||
self._params,
|
||||
dtrain,
|
||||
# fit the model
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=['train', 'valid'],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
- Override the `predict` method
|
||||
- The parameters include the test features.
|
||||
- The parameters must include `dataset`, which will be used to get the test dataset.
|
||||
- Return the `prediction score`.
|
||||
- Please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_ for the parameter types of the fit method.
|
||||
- Code Example: In the following example, users need to use dnn to predict the label(such as `preds`) of test data `x_test` and return it.
|
||||
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
|
||||
- Code Example: In the following example, users need to use `LightGBM` to predict the label (such as `preds`) of test data `x_test` and return it.
|
||||
.. code-block:: Python
|
||||
|
||||
def predict(self, x_test:pd.DataFrame, **kwargs)-> numpy.ndarray:
|
||||
if self._model is None:
|
||||
raise ValueError('model is not fitted yet!')
|
||||
return self._model.predict(x_test.values)
|
||||
def predict(self, dataset: DatasetH, **kwargs) -> pd.Series:
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `save` method & `load` method
|
||||
- The `save` method parameter includes the a `filename` that represents an absolute path, user need to save model into the path.
|
||||
- The `load` method parameter includes the a `buffer` read from the `filename` passed in the `save` method, users need to load model from the `buffer`.
|
||||
- Code Example:
|
||||
- Override the `finetune` method
|
||||
- The parameters must include `dataset`, which provides the data used for finetuning.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
def save(self, filename):
|
||||
if self._model is None:
|
||||
raise ValueError('model is not fitted yet!')
|
||||
self._model.save_model(filename)
|
||||
|
||||
def load(self, buffer):
|
||||
self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')})
|
||||
|
||||
.. Without tuner, this part will not be used
|
||||
.. - Override the `score` method(This step is optional)
|
||||
.. - The parameters include the test features and test labels.
|
||||
.. - Return the evaluation score of the model. It's recommended to adopt the loss between labels and `prediction score`.
|
||||
.. - Code Example: In the following example, users need to calculate the weighted loss with test data `x_test`, test label `y_test` and the weight `w_test`.
|
||||
.. .. code-block:: Python
|
||||
..
|
||||
.. def score(self, x_test:pd.Dataframe, y_test:pd.Dataframe, w_test:pd.DataFrame = None) -> float:
|
||||
.. # Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
.. x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
|
||||
.. preds = self.predict(x_test)
|
||||
.. w_test_weight = None if w_test is None else w_test.values
|
||||
.. scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score
|
||||
.. return scorer(y_test.values, preds, sample_weight=w_test_weight)
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
# Start from the existing model and finetune it by training for more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
)
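The contract described in the list above (`fit` and `predict` both receive a `dataset` and pull their own segments from it) can be sketched without Qlib installed. In the sketch below, `ToyDataset` and `MeanModel` are hypothetical stand-ins and not Qlib classes; only the `prepare(segment, col_set=...)` call shape mirrors the examples above, and plain lists are used where the examples above work with pandas objects.

```python
from statistics import mean


class ToyDataset:
    """Hypothetical stand-in for a dataset object; not part of Qlib."""

    def __init__(self, segments):
        # segments maps a segment name to {"feature": [...], "label": [...]}
        self._segments = segments

    def prepare(self, segment, col_set="feature"):
        # Return the requested column set of one segment.
        return self._segments[segment][col_set]


class MeanModel:
    """Toy model that predicts the mean training label for every test row."""

    def __init__(self):
        self.model = None  # set by fit()

    def fit(self, dataset):
        # The model asks the dataset for its own training labels.
        labels = dataset.prepare("train", col_set="label")
        self.model = mean(labels)

    def predict(self, dataset):
        if self.model is None:
            raise ValueError("model is not fitted yet!")
        x_test = dataset.prepare("test", col_set="feature")
        # One prediction per test row.
        return [self.model for _ in x_test]


toy = ToyDataset({
    "train": {"feature": [[1.0], [2.0], [3.0]], "label": [0.5, 1.0, 1.5]},
    "test": {"feature": [[4.0], [5.0]], "label": [2.0, 2.5]},
})
model = MeanModel()
model.fit(toy)
preds = model.predict(toy)
```

The point of the design is that the model, not the caller, decides which segments and column sets it needs, so the same `dataset` object can be handed to `fit`, `predict`, and `finetune` unchanged.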
|
||||
|
||||
Configuration File
|
||||
=======================
|
||||
|
||||
The configuration file is described in detail in the `estimator <../component/estimator.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file.
|
||||
The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file. The configuration describes which model to use and how to initialize it.
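As a rough illustration of how a `class` / `module_path` / `kwargs` entry can be turned into an object, the sketch below resolves such a dictionary with `importlib`. The helper name `init_from_config` is hypothetical and this is not Qlib's own loader; a standard-library class stands in for a model so the snippet runs without Qlib.

```python
import importlib


def init_from_config(config):
    """Hypothetical helper: build an object from a class/module_path/kwargs dict."""
    # Import the module named in the config, then look up the class on it.
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    # Pass the configured keyword arguments to the constructor.
    return cls(**config.get("kwargs", {}))


# Stand-in config using a standard-library class instead of a Qlib model.
example_config = {
    "class": "Fraction",
    "module_path": "fractions",
    "kwargs": {"numerator": 3, "denominator": 4},
}
obj = init_from_config(example_config)
```

Under this reading, pointing `module_path` and `class` at a custom model is all that is needed for the workflow to construct it.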
|
||||
|
||||
- Example: The following example describes the `model` field of the configuration file for the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` holds the hyperparameters passed into the __init__ method. All parameters in the field are passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
|
||||
|
||||
@@ -124,23 +122,23 @@ The configuration file is described in detail in the `estimator <../component/es
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
|
||||
Users could find configuration file of the baseline of the ``Model`` in ``qlib/examples/estimator/estimator_config.yaml`` and ``qlib/examples/estimator/estimator_config_dnn.yaml``
|
||||
Users can find the configuration files of the baseline models in ``examples/benchmarks``. The configurations of each model are listed under the corresponding model folder.
|
||||
|
||||
Model Testing
|
||||
=====================
|
||||
Assuming that the configuration file is ``examples/estimator/estimator_config.yaml``, users can run the following command to test the custom model:
|
||||
Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following command to test the custom model:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd examples  # Avoid running the program from a directory that contains `qlib`
|
||||
estimator -c estimator/estimator_config.yaml
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
|
||||
.. note:: ``estimator`` is a built-in command of ``Qlib``.
|
||||
.. note:: ``qrun`` is a built-in command of ``Qlib``.
|
||||
|
||||
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
``Model`` can also be tested as a single module. An example is given in ``examples/workflow_by_code.ipynb``.
|
||||
|
||||
|
||||
Reference
|
||||
=====================
|
||||
|
||||
To know more about ``Interday Model``, please refer to `Interday Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
|
||||
To know more about ``Interday Model``, please refer to `Interday Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.model.base>`_.
|
||||
|
||||
8
examples/benchmarks/ALSTM/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# ALSTM
|
||||
|
||||
- ALSTM contains a temporal attentive aggregation layer based on normal LSTM.
|
||||
|
||||
- Paper: A dual-stage attention-based recurrent neural network for time series prediction.
|
||||
|
||||
[https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
|
||||
|
||||
4
examples/benchmarks/ALSTM/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
83
examples/benchmarks/ALSTM/workflow_config_alstm.yaml
Normal file
@@ -0,0 +1,83 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ALSTM
|
||||
module_path: qlib.contrib.model.pytorch_alstm
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
rnn_type: GRU
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
3
examples/benchmarks/CatBoost/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# CatBoost
|
||||
* Code: [https://github.com/catboost/catboost](https://github.com/catboost/catboost)
|
||||
* Paper: CatBoost: unbiased boosting with categorical features. [https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf](https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf).
|
||||
3
examples/benchmarks/CatBoost/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
catboost==0.24.3
|
||||
64
examples/benchmarks/CatBoost/workflow_config_catboost.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: CatBoostModel
|
||||
module_path: qlib.contrib.model.catboost_model
|
||||
kwargs:
|
||||
loss: RMSE
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 6
|
||||
num_leaves: 100
|
||||
thread_count: 20
|
||||
grow_policy: Lossguide
|
||||
bootstrap_type: Poisson
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
5
examples/benchmarks/GATs/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# GATs
|
||||
* Graph Attention Networks(GATs) leverage masked self-attentional layers on graph-structured data. The nodes in stacked layers have different weights and they are able to attend over their
|
||||
neighborhoods’ features, without requiring any kind of costly matrix operation (such as inversion) or depending on knowing the graph structure upfront.
|
||||
* The code used in Qlib is our own implementation, written in PyTorch.
|
||||
* Paper: Graph Attention Networks https://arxiv.org/pdf/1710.10903.pdf
|
||||
4
examples/benchmarks/GATs/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
77
examples/benchmarks/GATs/workflow_config_gats.yaml
Normal file
@@ -0,0 +1,77 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GATs
|
||||
module_path: qlib.contrib.model.pytorch_gats
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.7
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
BIN
examples/benchmarks/GRU/model_gru_csi300.pkl
Normal file
4
examples/benchmarks/GRU/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
82
examples/benchmarks/GRU/workflow_config_gru.yaml
Normal file
@@ -0,0 +1,82 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: GRU
|
||||
module_path: qlib.contrib.model.pytorch_gru
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
BIN
examples/benchmarks/LSTM/model_lstm_csi300.pkl
Normal file
4
examples/benchmarks/LSTM/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
82
examples/benchmarks/LSTM/workflow_config_lstm.yaml
Normal file
@@ -0,0 +1,82 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LSTM
|
||||
module_path: qlib.contrib.model.pytorch_lstm
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
4
examples/benchmarks/LightGBM/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# LightGBM
* Code: [https://github.com/microsoft/LightGBM](https://github.com/microsoft/LightGBM)
* Paper: LightGBM: A Highly Efficient Gradient Boosting
Decision Tree. [https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf).
|
||||
3
examples/benchmarks/LightGBM/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
65
examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
3
examples/benchmarks/Linear/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
numpy>=1.17.4
|
||||
pandas>=1.0.1
|
||||
scikit-learn>=0.23.1
|
||||
71
examples/benchmarks/Linear/workflow_config_linear.yaml
Normal file
@@ -0,0 +1,71 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LinearModel
|
||||
module_path: qlib.contrib.model.linear
|
||||
kwargs:
|
||||
estimator: ols
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: True
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
4
examples/benchmarks/MLP/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
93
examples/benchmarks/MLP/workflow_config_mlp.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "CSZFillna",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
}
|
||||
]
|
||||
learn_processors: [
|
||||
{
|
||||
"class" : "DropCol",
|
||||
"kwargs":{"col_list": ["VWAP0"]}
|
||||
},
|
||||
{
|
||||
"class" : "DropnaProcessor",
|
||||
"kwargs":{"fields_group": "feature"}
|
||||
},
|
||||
"DropnaLabel",
|
||||
{
|
||||
"class": "CSZScoreNorm",
|
||||
"kwargs": {"fields_group": "label"}
|
||||
}
|
||||
]
|
||||
process_type: "independent"
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DNNModelPytorch
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 157
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: adam
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
3
examples/benchmarks/SFM/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# State-Frequency-Memory
- State Frequency Memory (SFM) is a novel recurrent network that uses the Discrete Fourier Transform to decompose the hidden states of memory cells and capture multi-frequency trading patterns from past market data to make stock price predictions.
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf).
|
||||
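The decomposition idea mentioned above can be illustrated with a plain Discrete Fourier Transform on a toy series. This is only a minimal sketch of the DFT itself, not the SFM cell: the `dft` helper, the toy `signal`, and the window length `n` are all assumptions made for illustration, and the real model applies the decomposition to memory-cell states inside a recurrent network.

```python
import cmath
import math


def dft(signal):
    """Naive O(n^2) Discrete Fourier Transform of a real-valued sequence."""
    n = len(signal)
    return [
        sum(signal[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


n = 16  # hypothetical window length
# Toy series: a slow cycle (1 per window) plus a weaker, faster cycle (4 per window).
signal = [
    math.sin(2 * math.pi * 1 * t / n) + 0.5 * math.sin(2 * math.pi * 4 * t / n)
    for t in range(n)
]

# Amplitude of each frequency component below the Nyquist frequency.
amplitudes = [2 * abs(c) / n for c in dft(signal)[: n // 2]]

# The two strongest components recover the two cycles that were mixed in.
dominant = sorted(range(n // 2), key=lambda k: amplitudes[k], reverse=True)[:2]
print(sorted(dominant))
```

Each entry of `amplitudes` measures how strongly one frequency is present in the window, which is the sense in which a series can be split into multi-frequency components.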
4
examples/benchmarks/SFM/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
85
examples/benchmarks/SFM/workflow_config_sfm.yaml
Normal file
@@ -0,0 +1,85 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: SFM
|
||||
module_path: qlib.contrib.model.pytorch_sfm
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
output_dim: 32
|
||||
freq_dim: 25
|
||||
dropout_W: 0.5
|
||||
dropout_U: 0.5
|
||||
n_epochs: 20
|
||||
lr: 1e-3
|
||||
batch_size: 1600
|
||||
early_stop: 20
|
||||
eval_steps: 5
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
GPU: 1
|
||||
seed: 710
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
14
examples/benchmarks/TFT/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Temporal Fusion Transformers Benchmark
## Source
**Reference**: Lim, Bryan, et al. "Temporal fusion transformers for interpretable multi-horizon time series forecasting." arXiv preprint arXiv:1912.09363 (2019).

**GitHub**: https://github.com/google-research/google-research/tree/master/tft

## Run the Workflow
Users can follow ``workflow_by_code_tft.py`` to run the benchmark.

### Notes
1. Please be **aware** that this script supports only `Python 3.5 - 3.8`.
2. If the CUDA version on your machine is not 10.0, please run `conda install anaconda cudatoolkit=10.0` and `conda install cudnn`.
3. The model must run on a GPU; otherwise an error will be raised.
4. New datasets should be registered in ``data_formatters``; for details, please see the source.
|
||||
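As a rough illustration of what registering a new dataset involves, the sketch below declares a column definition for a hypothetical dataset and reorders it the way `get_column_definition` in `data_formatters/base.py` does (ID, time, real-valued, then categorical columns). The enums are re-declared here only to keep the snippet self-contained; `MY_COLUMN_DEFINITION`, `my_feature`, and `order_columns` are made-up names. A real formatter would subclass `GenericDataFormatter` and, judging by `expt_settings/configs.py` in this change, would presumably also need entries in `default_experiments`, the `data_csv_path` map, and `make_data_formatter`.

```python
import enum


class DataTypes(enum.IntEnum):
    REAL_VALUED = 0
    CATEGORICAL = 1
    DATE = 2


class InputTypes(enum.IntEnum):
    TARGET = 0
    OBSERVED_INPUT = 1
    KNOWN_INPUT = 2
    STATIC_INPUT = 3
    ID = 4
    TIME = 5


# Hypothetical column definition: (column name, data type, input type).
MY_COLUMN_DEFINITION = [
    ("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
    ("date", DataTypes.DATE, InputTypes.TIME),
    ("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
    ("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
    ("my_feature", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
]


def order_columns(column_definition):
    """Order columns as ID, time, real-valued inputs, categorical inputs."""
    special = {InputTypes.ID, InputTypes.TIME}
    # Exactly one ID column and one time column must be present.
    for input_type in (InputTypes.ID, InputTypes.TIME):
        count = sum(1 for tup in column_definition if tup[2] == input_type)
        if count != 1:
            raise ValueError("Illegal number of inputs ({}) of type {}".format(count, input_type))
    identifier = [t for t in column_definition if t[2] == InputTypes.ID]
    time = [t for t in column_definition if t[2] == InputTypes.TIME]
    real = [t for t in column_definition if t[1] == DataTypes.REAL_VALUED and t[2] not in special]
    cat = [t for t in column_definition if t[1] == DataTypes.CATEGORICAL and t[2] not in special]
    return identifier + time + real + cat


ordered_names = [name for name, _, _ in order_columns(MY_COLUMN_DEFINITION)]
print(ordered_names)
```

Declaring a second `InputTypes.ID` or `InputTypes.TIME` column makes `order_columns` raise a `ValueError`, matching the sanity check in the base class.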
14
examples/benchmarks/TFT/data_formatters/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
223
examples/benchmarks/TFT/data_formatters/base.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Default data formatting functions for experiments.
|
||||
|
||||
For new datasets, inherit from GenericDataFormatter and implement
|
||||
all abstract functions.
|
||||
|
||||
These dataset-specific methods:
|
||||
1) Define the column and input types for tabular dataframes used by model
|
||||
2) Perform the necessary input feature engineering & normalisation steps
|
||||
3) Revert the normalisation for predictions
|
||||
4) Are responsible for train, validation and test splits
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import abc
|
||||
import enum
|
||||
|
||||
|
||||
# Type definitions
|
||||
class DataTypes(enum.IntEnum):
|
||||
"""Defines numerical types of each column."""
|
||||
|
||||
REAL_VALUED = 0
|
||||
CATEGORICAL = 1
|
||||
DATE = 2
|
||||
|
||||
|
||||
class InputTypes(enum.IntEnum):
|
||||
"""Defines input types of each column."""
|
||||
|
||||
TARGET = 0
|
||||
OBSERVED_INPUT = 1
|
||||
KNOWN_INPUT = 2
|
||||
STATIC_INPUT = 3
|
||||
ID = 4 # Single column used as an entity identifier
|
||||
TIME = 5 # Single column exclusively used as a time index
|
||||
|
||||
|
||||
class GenericDataFormatter(abc.ABC):
|
||||
"""Abstract base class for all data formatters.
|
||||
|
||||
Users can implement the abstract methods below to perform dataset-specific
|
||||
manipulations.
|
||||
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_scalers(self, df):
|
||||
"""Calibrates scalers using the data supplied."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def transform_inputs(self, df):
|
||||
"""Performs feature transformation."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def format_predictions(self, df):
|
||||
"""Reverts any normalisation to give predictions in original scale."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def split_data(self, df):
|
||||
"""Performs the default train, validation and test splits."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _column_definition(self):
|
||||
"""Defines order, input type and data type of each column."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_fixed_params(self):
|
||||
"""Defines the fixed parameters used by the model for training.
|
||||
|
||||
Requires the following keys:
|
||||
'total_time_steps': Defines the total number of time steps used by TFT
|
||||
'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
|
||||
'num_epochs': Maximum number of epochs for training
|
||||
'early_stopping_patience': Early stopping param for keras
|
||||
'multiprocessing_workers': # of cpus for data processing
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary of fixed parameters, e.g.:
|
||||
|
||||
fixed_params = {
|
||||
'total_time_steps': 252 + 5,
|
||||
'num_encoder_steps': 252,
|
||||
'num_epochs': 100,
|
||||
'early_stopping_patience': 5,
|
||||
'multiprocessing_workers': 5,
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Shared functions across data-formatters
|
||||
@property
|
||||
def num_classes_per_cat_input(self):
|
||||
"""Returns number of categories per relevant input.
|
||||
|
||||
This is subsequently required for keras embedding layers.
|
||||
"""
|
||||
return self._num_classes_per_cat_input
|
||||
|
||||
def get_num_samples_for_calibration(self):
|
||||
"""Gets the default number of training and validation samples.
|
||||
|
||||
Used to sub-sample the data for network calibration; a value of -1 uses
|
||||
all available samples.
|
||||
|
||||
Returns:
|
||||
Tuple of (training samples, validation samples)
|
||||
"""
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
# Sanity checks first.
|
||||
# Ensure only one ID and time column exist
|
||||
def _check_single_column(input_type):
|
||||
|
||||
length = len([tup for tup in column_definition if tup[2] == input_type])
|
||||
|
||||
if length != 1:
|
||||
raise ValueError("Illegal number of inputs ({}) of type {}".format(length, input_type))
|
||||
|
||||
_check_single_column(InputTypes.ID)
|
||||
_check_single_column(InputTypes.TIME)
|
||||
|
||||
identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]
|
||||
time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
|
||||
real_inputs = [
|
||||
tup
|
||||
for tup in column_definition
|
||||
if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME}
|
||||
]
|
||||
categorical_inputs = [
|
||||
tup
|
||||
for tup in column_definition
|
||||
if tup[1] == DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME}
|
||||
]
|
||||
|
||||
return identifier + time + real_inputs + categorical_inputs
|
||||
|
||||
def _get_input_columns(self):
|
||||
"""Returns names of all input columns."""
|
||||
return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
|
||||
|
||||
def _get_tft_input_indices(self):
|
||||
"""Returns the relevant indexes and input sizes required by TFT."""
|
||||
|
||||
# Functions
|
||||
def _extract_tuples_from_data_type(data_type, defn):
|
||||
return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}]
|
||||
|
||||
def _get_locations(input_types, defn):
|
||||
return [i for i, tup in enumerate(defn) if tup[2] in input_types]
|
||||
|
||||
# Start extraction
|
||||
column_definition = [
|
||||
tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}
|
||||
]
|
||||
|
||||
categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition)
|
||||
real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition)
|
||||
|
||||
locations = {
|
||||
"input_size": len(self._get_input_columns()),
|
||||
"output_size": len(_get_locations({InputTypes.TARGET}, column_definition)),
|
||||
"category_counts": self.num_classes_per_cat_input,
|
||||
"input_obs_loc": _get_locations({InputTypes.TARGET}, column_definition),
|
||||
"static_input_loc": _get_locations({InputTypes.STATIC_INPUT}, column_definition),
|
||||
"known_regular_inputs": _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs),
|
||||
"known_categorical_inputs": _get_locations(
|
||||
{InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, categorical_inputs
|
||||
),
|
||||
}
|
||||
|
||||
return locations
|
||||
|
||||
def get_experiment_params(self):
|
||||
"""Returns fixed model parameters for experiments."""
|
||||
|
||||
required_keys = [
|
||||
"total_time_steps",
|
||||
"num_encoder_steps",
|
||||
"num_epochs",
|
||||
"early_stopping_patience",
|
||||
"multiprocessing_workers",
|
||||
]
|
||||
|
||||
fixed_params = self.get_fixed_params()
|
||||
|
||||
for k in required_keys:
|
||||
if k not in fixed_params:
|
||||
raise ValueError("Field {}".format(k) + " missing from fixed parameter definitions!")
|
||||
|
||||
fixed_params["column_definition"] = self.get_column_definition()
|
||||
|
||||
fixed_params.update(self._get_tft_input_indices())
|
||||
|
||||
return fixed_params
|
||||
219
examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Custom formatting functions for Alpha158 dataset.
|
||||
|
||||
Defines dataset specific column definitions and data transformations.
|
||||
"""
|
||||
|
||||
import data_formatters.base
|
||||
import libs.utils as utils
|
||||
import sklearn.preprocessing
|
||||
|
||||
GenericDataFormatter = data_formatters.base.GenericDataFormatter
|
||||
DataTypes = data_formatters.base.DataTypes
|
||||
InputTypes = data_formatters.base.InputTypes
|
||||
|
||||
|
||||
class Alpha158Formatter(GenericDataFormatter):
|
||||
"""Defines and formats data for the Alpha158 dataset.
|
||||
|
||||
Attributes:
|
||||
column_definition: Defines input and data type of column used in the
|
||||
experiment.
|
||||
identifiers: Entity identifiers used in experiments.
|
||||
"""
|
||||
|
||||
_column_definition = [
|
||||
("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
|
||||
("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
|
||||
("date", DataTypes.DATE, InputTypes.TIME),
|
||||
("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
# Selected 10 features
|
||||
("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialises formatter."""
|
||||
|
||||
self.identifiers = None
|
||||
self._real_scalers = None
|
||||
self._cat_scalers = None
|
||||
self._target_scaler = None
|
||||
self._num_classes_per_cat_input = None
|
||||
|
||||
def split_data(self, df, valid_boundary=2016, test_boundary=2018):
|
||||
"""Splits data frame into training-validation-test data frames.
|
||||
|
||||
This also calibrates the scaling objects and transforms the data for each split.
|
||||
|
||||
Args:
|
||||
df: Source data frame to split.
|
||||
valid_boundary: Starting year for validation data
|
||||
test_boundary: Starting year for test data
|
||||
|
||||
Returns:
|
||||
Tuple of transformed (train, valid, test) data.
|
||||
"""
|
||||
|
||||
print("Formatting train-valid-test splits.")
|
||||
|
||||
index = df["year"]
|
||||
train = df.loc[index < valid_boundary]
|
||||
valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
|
||||
test = df.loc[index >= test_boundary]
|
||||
|
||||
self.set_scalers(train)
|
||||
|
||||
return tuple(self.transform_inputs(data) for data in [train, valid, test])
|
||||
|
||||
def set_scalers(self, df):
|
||||
"""Calibrates scalers using the data supplied.
|
||||
|
||||
Args:
|
||||
df: Data to use to calibrate scalers.
|
||||
"""
|
||||
print("Setting scalers with training data...")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
|
||||
target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
|
||||
|
||||
# Extract identifiers in case required
|
||||
self.identifiers = list(df[id_column].unique())
|
||||
|
||||
# Format real scalers
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
data = df[real_inputs].values
|
||||
self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
|
||||
self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
|
||||
df[[target_column]].values
|
||||
) # used for predictions
|
||||
|
||||
# Format categorical scalers
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
categorical_scalers = {}
|
||||
num_classes = []
|
||||
for col in categorical_inputs:
|
||||
# Set all to str so that we don't have mixed integer/string columns
|
||||
srs = df[col].apply(str)
|
||||
categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
|
||||
num_classes.append(srs.nunique())
|
||||
|
||||
# Set categorical scaler outputs
|
||||
self._cat_scalers = categorical_scalers
|
||||
self._num_classes_per_cat_input = num_classes
|
||||
|
||||
def transform_inputs(self, df):
|
||||
"""Performs feature transformations.
|
||||
|
||||
This includes feature engineering, preprocessing and normalisation.
|
||||
|
||||
Args:
|
||||
df: Data frame to transform.
|
||||
|
||||
Returns:
|
||||
Transformed data frame.
|
||||
|
||||
"""
|
||||
output = df.copy()
|
||||
|
||||
if self._real_scalers is None or self._cat_scalers is None:
|
||||
raise ValueError("Scalers have not been set!")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
# Format real inputs
|
||||
output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
|
||||
|
||||
# Format categorical inputs
|
||||
for col in categorical_inputs:
|
||||
string_df = df[col].apply(str)
|
||||
output[col] = self._cat_scalers[col].transform(string_df)
|
||||
|
||||
return output
|
||||
|
||||
def format_predictions(self, predictions):
|
||||
"""Reverts any normalisation to give predictions in original scale.
|
||||
|
||||
Args:
|
||||
predictions: Dataframe of model predictions.
|
||||
|
||||
Returns:
|
||||
Data frame of unnormalised predictions.
|
||||
"""
|
||||
output = predictions.copy()
|
||||
|
||||
column_names = predictions.columns
|
||||
|
||||
for col in column_names:
|
||||
if col not in {"forecast_time", "identifier"}:
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[[col]].values).ravel()
|
||||
|
||||
return output
|
||||
|
||||
# Default params
|
||||
def get_fixed_params(self):
|
||||
"""Returns fixed model parameters for experiments."""
|
||||
|
||||
fixed_params = {
|
||||
"total_time_steps": 6 + 6,
|
||||
"num_encoder_steps": 6,
|
||||
"num_epochs": 100,
|
||||
"early_stopping_patience": 10,
|
||||
"multiprocessing_workers": 5,
|
||||
}
|
||||
|
||||
return fixed_params
|
||||
|
||||
def get_default_model_params(self):
|
||||
"""Returns default optimised model parameters."""
|
||||
|
||||
model_params = {
|
||||
"dropout_rate": 0.4,
|
||||
"hidden_layer_size": 16,
|
||||
"learning_rate": 0.0001,
|
||||
"minibatch_size": 128,
|
||||
"max_gradient_norm": 0.0135,
|
||||
"num_heads": 1,
|
||||
"stack_size": 1,
|
||||
}
|
||||
|
||||
return model_params
|
||||
14
examples/benchmarks/TFT/expt_settings/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
95
examples/benchmarks/TFT/expt_settings/configs.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Default configs for TFT experiments.
|
||||
|
||||
Contains the default output paths for data, serialised models and predictions
|
||||
for the main experiments used in the publication.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import data_formatters.qlib_Alpha158
|
||||
|
||||
|
||||
class ExperimentConfig(object):
|
||||
"""Defines experiment configs and paths to outputs.
|
||||
|
||||
Attributes:
|
||||
root_folder: Root folder to contain all experimental outputs.
|
||||
experiment: Name of experiment to run.
|
||||
data_folder: Folder to store data for experiment.
|
||||
model_folder: Folder to store serialised models.
|
||||
results_folder: Folder to store results.
|
||||
data_csv_path: Path to primary data csv file used in experiment.
|
||||
hyperparam_iterations: Default number of random search iterations for
|
||||
experiment.
|
||||
"""
|
||||
|
||||
default_experiments = ["Alpha158"]
|
||||
|
||||
def __init__(self, experiment="Alpha158", root_folder=None):
|
||||
"""Creates configs based on default experiment chosen.
|
||||
|
||||
Args:
|
||||
experiment: Name of experiment.
|
||||
root_folder: Root folder to save all outputs of training.
|
||||
"""
|
||||
|
||||
if experiment not in self.default_experiments:
|
||||
raise ValueError("Unrecognised experiment={}".format(experiment))
|
||||
|
||||
# Defines all relevant paths
|
||||
if root_folder is None:
|
||||
root_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "outputs")
|
||||
print("Using root folder {}".format(root_folder))
|
||||
|
||||
self.root_folder = root_folder
|
||||
self.experiment = experiment
|
||||
self.data_folder = os.path.join(root_folder, "data", experiment)
|
||||
self.model_folder = os.path.join(root_folder, "saved_models", experiment)
|
||||
self.results_folder = os.path.join(root_folder, "results", experiment)
|
||||
|
||||
# Creates folders if they don't exist
|
||||
for relevant_directory in [self.root_folder, self.data_folder, self.model_folder, self.results_folder]:
|
||||
if not os.path.exists(relevant_directory):
|
||||
os.makedirs(relevant_directory)
|
||||
|
||||
@property
|
||||
def data_csv_path(self):
|
||||
csv_map = {
|
||||
"Alpha158": "Alpha158.csv",
|
||||
}
|
||||
|
||||
return os.path.join(self.data_folder, csv_map[self.experiment])
|
||||
|
||||
@property
|
||||
def hyperparam_iterations(self):
|
||||
|
||||
return 240 if self.experiment == "volatility" else 60
|
||||
|
||||
def make_data_formatter(self):
|
||||
"""Gets a data formatter object for experiment.
|
||||
|
||||
Returns:
|
||||
Default DataFormatter per experiment.
|
||||
"""
|
||||
|
||||
data_formatter_class = {
|
||||
"Alpha158": data_formatters.qlib_Alpha158.Alpha158Formatter,
|
||||
}
|
||||
|
||||
return data_formatter_class[self.experiment]()
|
||||
14
examples/benchmarks/TFT/libs/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
430
examples/benchmarks/TFT/libs/hyperparam_opt.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Classes used for hyperparameter optimisation.
|
||||
|
||||
Two main classes exist:
|
||||
1) HyperparamOptManager used for optimisation on a single machine/GPU.
|
||||
2) DistributedHyperparamOptManager for multiple GPUs on different machines.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
import libs.utils as utils
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
Deque = collections.deque
|
||||
|
||||
|
||||
class HyperparamOptManager:
|
||||
"""Manages hyperparameter optimisation using random search for a single GPU.
|
||||
|
||||
Attributes:
|
||||
param_ranges: Discrete hyperparameter range for random search.
|
||||
results: Dataframe of validation results.
|
||||
fixed_params: Fixed model parameters per experiment.
|
||||
saved_params: Dataframe of parameters trained.
|
||||
best_score: Minimum validation loss observed thus far.
|
||||
optimal_name: Key to best configuration.
|
||||
hyperparam_folder: Where to save optimisation outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True):
|
||||
"""Instantiates model.
|
||||
|
||||
Args:
|
||||
param_ranges: Discrete hyperparameter range for random search.
|
||||
fixed_params: Fixed model parameters per experiment.
|
||||
model_folder: Folder to store optimisation artifacts.
|
||||
override_w_fixed_params: Whether to override serialised fixed model
|
||||
parameters with new supplied values.
|
||||
"""
|
||||
|
||||
self.param_ranges = param_ranges
|
||||
|
||||
self._max_tries = 1000
|
||||
self.results = pd.DataFrame()
|
||||
self.fixed_params = fixed_params
|
||||
self.saved_params = pd.DataFrame()
|
||||
|
||||
self.best_score = np.Inf
|
||||
self.optimal_name = ""
|
||||
|
||||
# Setup
|
||||
# Create folder for saving if it's not there
|
||||
self.hyperparam_folder = model_folder
|
||||
utils.create_folder_if_not_exist(self.hyperparam_folder)
|
||||
|
||||
self._override_w_fixed_params = override_w_fixed_params
|
||||
|
||||
def load_results(self):
|
||||
"""Loads results from previous hyperparameter optimisation.
|
||||
|
||||
Returns:
|
||||
A boolean indicating if previous results can be loaded.
|
||||
"""
|
||||
print("Loading results from", self.hyperparam_folder)
|
||||
|
||||
results_file = os.path.join(self.hyperparam_folder, "results.csv")
|
||||
params_file = os.path.join(self.hyperparam_folder, "params.csv")
|
||||
|
||||
if os.path.exists(results_file) and os.path.exists(params_file):
|
||||
|
||||
self.results = pd.read_csv(results_file, index_col=0)
|
||||
self.saved_params = pd.read_csv(params_file, index_col=0)
|
||||
|
||||
if not self.results.empty:
|
||||
self.results.at["loss"] = self.results.loc["loss"].apply(float)
|
||||
self.best_score = self.results.loc["loss"].min()
|
||||
|
||||
is_optimal = self.results.loc["loss"] == self.best_score
|
||||
self.optimal_name = self.results.T[is_optimal].index[0]
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_params_from_name(self, name):
|
||||
"""Returns previously saved parameters given a key."""
|
||||
params = self.saved_params
|
||||
|
||||
selected_params = dict(params[name])
|
||||
|
||||
if self._override_w_fixed_params:
|
||||
for k in self.fixed_params:
|
||||
selected_params[k] = self.fixed_params[k]
|
||||
|
||||
return selected_params
|
||||
|
||||
def get_best_params(self):
|
||||
"""Returns the optimal hyperparameters thus far."""
|
||||
|
||||
optimal_name = self.optimal_name
|
||||
|
||||
return self._get_params_from_name(optimal_name)
|
||||
|
||||
def clear(self):
|
||||
"""Clears all previous results and saved parameters."""
|
||||
shutil.rmtree(self.hyperparam_folder)
|
||||
os.makedirs(self.hyperparam_folder)
|
||||
self.results = pd.DataFrame()
|
||||
self.saved_params = pd.DataFrame()
|
||||
|
||||
def _check_params(self, params):
|
||||
"""Checks that parameter map is properly defined."""
|
||||
|
||||
valid_fields = list(self.param_ranges.keys()) + list(self.fixed_params.keys())
|
||||
invalid_fields = [k for k in params if k not in valid_fields]
|
||||
missing_fields = [k for k in valid_fields if k not in params]
|
||||
|
||||
if invalid_fields:
|
||||
raise ValueError("Invalid Fields Found {} - Valid ones are {}".format(invalid_fields, valid_fields))
|
||||
if missing_fields:
|
||||
raise ValueError("Missing Fields Found {} - Valid ones are {}".format(missing_fields, valid_fields))
|
||||
|
||||
def _get_name(self, params):
|
||||
"""Returns a unique key for the supplied set of params."""
|
||||
|
||||
self._check_params(params)
|
||||
|
||||
fields = list(params.keys())
|
||||
fields.sort()
|
||||
|
||||
return "_".join([str(params[k]) for k in fields])
|
||||
|
||||
def get_next_parameters(self, ranges_to_skip=None):
|
||||
"""Returns the next set of parameters to optimise.
|
||||
|
||||
Args:
|
||||
ranges_to_skip: Explicitly defines a set of keys to skip.
|
||||
"""
|
||||
if ranges_to_skip is None:
|
||||
ranges_to_skip = set(self.results.index)
|
||||
|
||||
if not isinstance(self.param_ranges, dict):
|
||||
raise ValueError("Only works for random search!")
|
||||
|
||||
param_range_keys = list(self.param_ranges.keys())
|
||||
param_range_keys.sort()
|
||||
|
||||
def _get_next():
|
||||
"""Returns next hyperparameter set per try."""
|
||||
|
||||
parameters = {k: np.random.choice(self.param_ranges[k]) for k in param_range_keys}
|
||||
|
||||
# Adds fixed params
|
||||
for k in self.fixed_params:
|
||||
parameters[k] = self.fixed_params[k]
|
||||
|
||||
return parameters
|
||||
|
||||
for _ in range(self._max_tries):
|
||||
|
||||
parameters = _get_next()
|
||||
name = self._get_name(parameters)
|
||||
|
||||
if name not in ranges_to_skip:
|
||||
return parameters
|
||||
|
||||
raise ValueError("Exceeded max number of hyperparameter searches!!")
|
||||
|
||||
def update_score(self, parameters, loss, model, info=""):
|
||||
"""Updates the results from last optimisation run.
|
||||
|
||||
Args:
|
||||
parameters: Hyperparameters used in optimisation.
|
||||
loss: Validation loss obtained.
|
||||
model: Model to serialise if required.
|
||||
info: Any ancillary information to tag on to results.
|
||||
|
||||
Returns:
|
||||
Boolean flag indicating if the model is the best seen so far.
|
||||
"""
|
||||
|
||||
if np.isnan(loss):
|
||||
loss = np.Inf
|
||||
|
||||
if not os.path.isdir(self.hyperparam_folder):
|
||||
os.makedirs(self.hyperparam_folder)
|
||||
|
||||
name = self._get_name(parameters)
|
||||
|
||||
is_optimal = self.results.empty or loss < self.best_score
|
||||
|
||||
# save the first model
|
||||
if is_optimal:
|
||||
# Try saving first, before updating info
|
||||
if model is not None:
|
||||
print("Optimal model found, updating")
|
||||
model.save(self.hyperparam_folder)
|
||||
self.best_score = loss
|
||||
self.optimal_name = name
|
||||
|
||||
self.results[name] = pd.Series({"loss": loss, "info": info})
|
||||
self.saved_params[name] = pd.Series(parameters)
|
||||
|
||||
self.results.to_csv(os.path.join(self.hyperparam_folder, "results.csv"))
|
||||
self.saved_params.to_csv(os.path.join(self.hyperparam_folder, "params.csv"))
|
||||
|
||||
return is_optimal
|
||||
|
||||
|
||||
class DistributedHyperparamOptManager(HyperparamOptManager):
|
||||
"""Manages distributed hyperparameter optimisation across many GPUs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
param_ranges,
|
||||
fixed_params,
|
||||
root_model_folder,
|
||||
worker_number,
|
||||
search_iterations=1000,
|
||||
num_iterations_per_worker=5,
|
||||
clear_serialised_params=False,
|
||||
):
|
||||
"""Instantiates optimisation manager.
|
||||
|
||||
This hyperparameter optimisation pre-generates #search_iterations
|
||||
hyperparameter combinations and serialises them
|
||||
at the start. At runtime, each worker goes through their own set of
|
||||
parameter ranges. The pregeneration
|
||||
allows for multiple workers to run in parallel on different machines without
|
||||
resulting in parameter overlaps.
|
||||
|
||||
Args:
|
||||
param_ranges: Discrete hyperparameter range for random search.
|
||||
fixed_params: Fixed model parameters per experiment.
|
||||
root_model_folder: Folder to store optimisation artifacts.
|
||||
worker_number: Worker index defining which set of hyperparameters to
|
||||
test.
|
||||
search_iterations: Maximum number of random search iterations.
|
||||
num_iterations_per_worker: How many iterations are handled per worker.
|
||||
clear_serialised_params: Whether to regenerate hyperparameter
|
||||
combinations.
|
||||
"""
|
||||
|
||||
max_workers = int(np.ceil(search_iterations / num_iterations_per_worker))
|
||||
|
||||
# Sanity checks
|
||||
if worker_number > max_workers:
|
||||
raise ValueError(
|
||||
"Worker number ({}) cannot exceed the total number of workers ({})!".format(worker_number, max_workers)
|
||||
)
|
||||
if worker_number > search_iterations:
|
||||
raise ValueError(
|
||||
"Worker number ({}) cannot be larger than the max search iterations ({})!".format(
|
||||
worker_number, search_iterations
|
||||
)
|
||||
)
|
||||
|
||||
print("*** Creating hyperparameter manager for worker {} ***".format(worker_number))
|
||||
|
||||
hyperparam_folder = os.path.join(root_model_folder, str(worker_number))
|
||||
super().__init__(param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True)
|
||||
|
||||
serialised_ranges_folder = os.path.join(root_model_folder, "hyperparams")
|
||||
if clear_serialised_params:
|
||||
print("Regenerating hyperparameter list")
|
||||
if os.path.exists(serialised_ranges_folder):
|
||||
shutil.rmtree(serialised_ranges_folder)
|
||||
|
||||
utils.create_folder_if_not_exist(serialised_ranges_folder)
|
||||
|
||||
self.serialised_ranges_path = os.path.join(serialised_ranges_folder, "ranges_{}.csv".format(search_iterations))
|
||||
self.hyperparam_folder = hyperparam_folder # override
|
||||
self.worker_num = worker_number
|
||||
self.total_search_iterations = search_iterations
|
||||
self.num_iterations_per_worker = num_iterations_per_worker
|
||||
self.global_hyperparam_df = self.load_serialised_hyperparam_df()
|
||||
self.worker_search_queue = self._get_worker_search_queue()
|
||||
|
||||
@property
|
||||
def optimisation_completed(self):
|
||||
return not self.worker_search_queue
|
||||
|
||||
def get_next_parameters(self):
|
||||
"""Returns next dictionary of hyperparameters to optimise."""
|
||||
param_name = self.worker_search_queue.pop()
|
||||
|
||||
params = self.global_hyperparam_df.loc[param_name, :].to_dict()
|
||||
|
||||
# Always override!
|
||||
for k in self.fixed_params:
|
||||
print("Overriding saved {}: {}".format(k, self.fixed_params[k]))
|
||||
|
||||
params[k] = self.fixed_params[k]
|
||||
|
||||
return params
|
||||
|
||||
def load_serialised_hyperparam_df(self):
|
||||
"""Loads serialised hyperparameter ranges from file.
|
||||
|
||||
Returns:
|
||||
DataFrame containing hyperparameter combinations.
|
||||
"""
|
||||
print(
|
||||
"Loading params for {} search iterations from {}".format(
|
||||
self.total_search_iterations, self.serialised_ranges_path
|
||||
)
|
||||
)
|
||||
|
||||
if os.path.exists(self.serialised_ranges_path):
|
||||
df = pd.read_csv(self.serialised_ranges_path, index_col=0)
|
||||
else:
|
||||
print("Unable to load - regenerating search ranges instead")
|
||||
df = self.update_serialised_hyperparam_df()
|
||||
|
||||
return df
|
||||
|
||||
def update_serialised_hyperparam_df(self):
|
||||
"""Regenerates hyperparameter combinations and saves to file.
|
||||
|
||||
Returns:
|
||||
DataFrame containing hyperparameter combinations.
|
||||
"""
|
||||
search_df = self._generate_full_hyperparam_df()
|
||||
|
||||
print(
|
||||
"Serialising params for {} search iterations to {}".format(
|
||||
self.total_search_iterations, self.serialised_ranges_path
|
||||
)
|
||||
)
|
||||
|
||||
search_df.to_csv(self.serialised_ranges_path)
|
||||
|
||||
return search_df
|
||||
|
||||
def _generate_full_hyperparam_df(self):
|
||||
"""Generates actual hyperparameter combinations.
|
||||
|
||||
Returns:
|
||||
DataFrame containing hyperparameter combinations.
|
||||
"""
|
||||
|
||||
np.random.seed(131) # for reproducibility of hyperparam list
|
||||
|
||||
name_list = []
|
||||
param_list = []
|
||||
for _ in range(self.total_search_iterations):
|
||||
params = super().get_next_parameters(name_list)
|
||||
|
||||
name = self._get_name(params)
|
||||
|
||||
name_list.append(name)
|
||||
param_list.append(params)
|
||||
|
||||
full_search_df = pd.DataFrame(param_list, index=name_list)
|
||||
|
||||
return full_search_df
|
||||
|
||||
def clear(self): # reset when cleared
|
||||
"""Clears results for hyperparameter manager and resets."""
|
||||
super().clear()
|
||||
self.worker_search_queue = self._get_worker_search_queue()
|
||||
|
||||
def load_results(self):
|
||||
"""Load results from file and queue parameter combinations to try.
|
||||
|
||||
Returns:
|
||||
Boolean indicating if results were successfully loaded.
|
||||
"""
|
||||
success = super().load_results()
|
||||
|
||||
if success:
|
||||
self.worker_search_queue = self._get_worker_search_queue()
|
||||
|
||||
return success
|
||||
|
||||
def _get_worker_search_queue(self):
|
||||
"""Generates the queue of param combinations for current worker.
|
||||
|
||||
Returns:
|
||||
Queue of hyperparameter combinations outstanding.
|
||||
"""
|
||||
global_df = self.assign_worker_numbers(self.global_hyperparam_df)
|
||||
worker_df = global_df[global_df["worker"] == self.worker_num]
|
||||
|
||||
left_overs = [s for s in worker_df.index if s not in self.results.columns]
|
||||
|
||||
return Deque(left_overs)
|
||||
|
||||
def assign_worker_numbers(self, df):
|
||||
"""Updates parameter combinations with the index of the worker used.
|
||||
|
||||
Args:
|
||||
df: DataFrame of parameter combinations.
|
||||
|
||||
Returns:
|
||||
Updated DataFrame with worker number.
|
||||
"""
|
||||
output = df.copy()
|
||||
|
||||
n = self.total_search_iterations
|
||||
batch_size = self.num_iterations_per_worker
|
||||
|
||||
max_worker_num = int(np.ceil(n / batch_size))
|
||||
|
||||
worker_idx = np.concatenate([np.tile(i + 1, self.num_iterations_per_worker) for i in range(max_worker_num)])
|
||||
|
||||
output["worker"] = worker_idx[: len(output)]
|
||||
|
||||
return output
|
||||
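The batching done by `assign_worker_numbers` above can be illustrated in isolation. The following is a minimal sketch, not the code from this diff: the standalone function signature and the demo DataFrame are assumptions made for illustration, and `np.repeat` is swapped in for the `np.tile` / `np.concatenate` combination used above.

```python
import numpy as np
import pandas as pd


def assign_worker_numbers(df, total_iterations, iterations_per_worker):
    """Tags each hyperparameter combination with a 1-based worker index.

    Hypothetical standalone version of the method above: consecutive rows
    are grouped into batches of `iterations_per_worker`, one batch per worker.
    """
    output = df.copy()
    max_worker_num = int(np.ceil(total_iterations / iterations_per_worker))

    # [1, 1, ..., 2, 2, ..., max_worker_num, ...], one entry per search slot.
    worker_idx = np.repeat(np.arange(1, max_worker_num + 1), iterations_per_worker)

    # The last batch may be partial, so trim to the number of rows present.
    output["worker"] = worker_idx[: len(output)]
    return output


# Illustrative data only: five combinations, two per worker.
demo = pd.DataFrame({"dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5]}, index=list("abcde"))
assigned = assign_worker_numbers(demo, total_iterations=5, iterations_per_worker=2)
print(assigned["worker"].tolist())
```

Because the combinations are generated once with a fixed seed and then sliced by position, each worker can rebuild the same assignment independently, without coordinating with the others.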
1280
examples/benchmarks/TFT/libs/tft_model.py
Normal file
224
examples/benchmarks/TFT/libs/utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Generic helper functions used across codebase."""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
|
||||
|
||||
|
||||
# Generic.
|
||||
def get_single_col_by_input_type(input_type, column_definition):
|
||||
"""Returns name of single column.
|
||||
|
||||
Args:
|
||||
input_type: Input type of column to extract
|
||||
column_definition: Column definition list for experiment
|
||||
"""
|
||||
|
||||
l = [tup[0] for tup in column_definition if tup[2] == input_type]
|
||||
|
||||
if len(l) != 1:
|
||||
raise ValueError("Invalid number of columns for {}".format(input_type))
|
||||
|
||||
return l[0]
|
||||
|
||||
|
||||
def extract_cols_from_data_type(data_type, column_definition, excluded_input_types):
|
||||
"""Extracts the names of columns that correspond to a defined data_type.
|
||||
|
||||
Args:
|
||||
data_type: DataType of columns to extract.
|
||||
column_definition: Column definition to use.
|
||||
excluded_input_types: Set of input types to exclude
|
||||
|
||||
Returns:
|
||||
List of names for columns with data type specified.
|
||||
"""
|
||||
return [tup[0] for tup in column_definition if tup[1] == data_type and tup[2] not in excluded_input_types]
|
||||
|
||||
|
||||
# Loss functions.
|
||||
def tensorflow_quantile_loss(y, y_pred, quantile):
|
||||
"""Computes quantile loss for tensorflow.
|
||||
|
||||
Standard quantile loss as defined in the "Training Procedure" section of
|
||||
the main TFT paper
|
||||
|
||||
Args:
|
||||
y: Targets
|
||||
y_pred: Predictions
|
||||
quantile: Quantile to use for loss calculations (between 0 & 1)
|
||||
|
||||
Returns:
|
||||
Tensor for quantile loss.
|
||||
"""
|
||||
|
||||
# Checks quantile
|
||||
if quantile < 0 or quantile > 1:
|
||||
raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))
|
||||
|
||||
prediction_underflow = y - y_pred
|
||||
q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * tf.maximum(
|
||||
-prediction_underflow, 0.0
|
||||
)
|
||||
|
||||
return tf.reduce_sum(q_loss, axis=-1)
|
||||
|
||||
|
||||
def numpy_normalised_quantile_loss(y, y_pred, quantile):
|
||||
"""Computes normalised quantile loss for numpy arrays.
|
||||
|
||||
Uses the q-Risk metric as defined in the "Training Procedure" section of the
|
||||
main TFT paper.
|
||||
|
||||
Args:
|
||||
y: Targets
|
||||
y_pred: Predictions
|
||||
quantile: Quantile to use for loss calculations (between 0 & 1)
|
||||
|
||||
Returns:
|
||||
Float for normalised quantile loss.
|
||||
"""
|
||||
prediction_underflow = y - y_pred
|
||||
weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
|
||||
-prediction_underflow, 0.0
|
||||
)
|
||||
|
||||
quantile_loss = weighted_errors.mean()
|
||||
normaliser = np.abs(y).mean()
|
||||
|
||||
return 2 * quantile_loss / normaliser
|
||||
|
||||
|
||||
# OS related functions.
|
||||
def create_folder_if_not_exist(directory):
|
||||
"""Creates folder if it doesn't exist.
|
||||
|
||||
Args:
|
||||
directory: Folder path to create.
|
||||
"""
|
||||
# Also creates directories recursively
|
||||
pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Tensorflow related functions.
|
||||
def get_default_tensorflow_config(tf_device="gpu", gpu_id=0):
|
||||
"""Creates tensorflow config for graphs to run on CPU or GPU.
|
||||
|
||||
Specifies whether to run graph on gpu or cpu and which GPU ID to use for multi
|
||||
GPU machines.
|
||||
|
||||
Args:
|
||||
tf_device: 'cpu' or 'gpu'
|
||||
gpu_id: GPU ID to use if relevant
|
||||
|
||||
Returns:
|
||||
Tensorflow config.
|
||||
"""
|
||||
|
||||
if tf_device == "cpu":
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # for training on cpu
|
||||
tf_config = tf.ConfigProto(log_device_placement=False, device_count={"GPU": 0})
|
||||
|
||||
else:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
print("Selecting GPU ID={}".format(gpu_id))
|
||||
|
||||
tf_config = tf.ConfigProto(log_device_placement=False)
|
||||
tf_config.gpu_options.allow_growth = True
|
||||
|
||||
return tf_config
|
||||
|
||||
|
||||
def save(tf_session, model_folder, cp_name, scope=None):
|
||||
"""Saves Tensorflow graph to checkpoint.
|
||||
|
||||
Saves all trainable variables under a given variable scope to checkpoint.
|
||||
|
||||
Args:
|
||||
tf_session: Session containing graph
|
||||
model_folder: Folder to save models
|
||||
cp_name: Name of Tensorflow checkpoint
|
||||
scope: Variable scope containing variables to save
|
||||
"""
|
||||
# Save model
|
||||
if scope is None:
|
||||
saver = tf.train.Saver()
|
||||
else:
|
||||
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
|
||||
saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
|
||||
|
||||
save_path = saver.save(tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name)))
|
||||
print("Model saved to: {0}".format(save_path))
|
||||
|
||||
|
||||
def load(tf_session, model_folder, cp_name, scope=None, verbose=False):
|
||||
"""Loads Tensorflow graph from checkpoint.
|
||||
|
||||
Args:
|
||||
tf_session: Session to load graph into
|
||||
model_folder: Folder containing serialised model
|
||||
cp_name: Name of Tensorflow checkpoint
|
||||
scope: Variable scope to use.
|
||||
verbose: Whether to print additional debugging information.
|
||||
"""
|
||||
# Load model proper
|
||||
load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
|
||||
|
||||
print("Loading model from {0}".format(load_path))
|
||||
|
||||
print_weights_in_checkpoint(model_folder, cp_name)
|
||||
|
||||
initial_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
|
||||
|
||||
# Saver
|
||||
if scope is None:
|
||||
saver = tf.train.Saver()
|
||||
else:
|
||||
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
|
||||
saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
|
||||
# Load
|
||||
saver.restore(tf_session, load_path)
|
||||
all_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
|
||||
|
||||
if verbose:
|
||||
print("Restored {0}".format(",".join(initial_vars.difference(all_vars))))
|
||||
print("Existing {0}".format(",".join(all_vars.difference(initial_vars))))
|
||||
print("All {0}".format(",".join(all_vars)))
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
def print_weights_in_checkpoint(model_folder, cp_name):
|
||||
"""Prints all weights in Tensorflow checkpoint.
|
||||
|
||||
Args:
|
||||
model_folder: Folder containing checkpoint
|
||||
cp_name: Name of checkpoint
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
|
||||
|
||||
print_tensors_in_checkpoint_file(file_name=load_path, tensor_name="", all_tensors=True, all_tensor_names=True)
|
||||
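The quantile loss helpers in `utils.py` above can be checked by hand on a small input. The following is a rough sketch in plain NumPy under stated assumptions: the function names `quantile_loss` and `normalised_quantile_loss` and the sample values are hypothetical, and the inputs are coerced to arrays rather than taken as DataFrames.

```python
import numpy as np


def quantile_loss(y, y_pred, quantile):
    """Element-wise pinball loss for a single quantile (hypothetical helper)."""
    if quantile < 0 or quantile > 1:
        raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))

    # Positive where the prediction is too low, negative where it is too high.
    prediction_underflow = np.asarray(y, dtype=float) - np.asarray(y_pred, dtype=float)
    return quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
        -prediction_underflow, 0.0
    )


def normalised_quantile_loss(y, y_pred, quantile):
    """Mean pinball loss scaled by 2 / mean(|y|), mirroring the formula above."""
    y = np.asarray(y, dtype=float)
    return 2 * quantile_loss(y, y_pred, quantile).mean() / np.abs(y).mean()


# Illustrative values only.
targets = [1.0, 2.0, 3.0]
forecasts = [2.0, 2.0, 1.0]
risk = normalised_quantile_loss(targets, forecasts, 0.9)
print(risk)
```

At the 0.9 quantile an under-prediction is weighted nine times more heavily than an over-prediction of the same size, which is why the third sample (forecast 1.0 against target 3.0) dominates the result.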
3
examples/benchmarks/TFT/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
tensorflow-gpu==1.15.0
|
||||
numpy==1.19.4
|
||||
pandas==1.1.0
|
||||
249
examples/benchmarks/TFT/tft.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
import data_formatters.base
|
||||
import expt_settings.configs
|
||||
import libs.hyperparam_opt
|
||||
import libs.tft_model
|
||||
import libs.utils as utils
|
||||
import os
|
||||
import datetime as dte
|
||||
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
# To register new datasets, please add them here.
|
||||
ALLOW_DATASET = ["Alpha158"]
|
||||
DATASET_SETTING = {
|
||||
"Alpha158": {
|
||||
"feature_col": ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "ROC60", "RESI10"],
|
||||
"label_col": ["LABEL0"],
|
||||
},
|
||||
}
|
||||
# To register new datasets, please add their configurations here.
|
||||
|
||||
|
||||
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
|
||||
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
|
||||
|
||||
|
||||
def fill_test_na(test_df):
|
||||
test_df_res = test_df.copy()
|
||||
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
|
||||
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
|
||||
test_df_res.loc[:, feature_cols] = test_feature_fna
|
||||
return test_df_res
|
||||
|
||||
|
||||
def process_qlib_data(df, dataset, fillna=False):
|
||||
"""Prepare data to fit the TFT model.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
fillna: Whether to fill the data with the mean values.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
# Several features selected manually
|
||||
feature_col = DATASET_SETTING[dataset]["feature_col"]
|
||||
label_col = DATASET_SETTING[dataset]["label_col"]
|
||||
temp_df = df.loc[:, feature_col + label_col]
|
||||
if fillna:
|
||||
temp_df = fill_test_na(temp_df)
|
||||
temp_df = temp_df.swaplevel()
|
||||
temp_df = temp_df.sort_index()
|
||||
temp_df = temp_df.reset_index(level=0)
|
||||
dates = pd.to_datetime(temp_df.index)
|
||||
temp_df["date"] = dates
|
||||
temp_df["day_of_week"] = dates.dayofweek
|
||||
temp_df["month"] = dates.month
|
||||
temp_df["year"] = dates.year
|
||||
temp_df["const"] = 1.0
|
||||
return temp_df
|
||||
|
||||
|
||||
def process_predicted(df, col_name):
|
||||
"""Transform the TFT predicted data into Qlib format.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
col_name: New column name.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
df_res = df.copy()
|
||||
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
|
||||
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
|
||||
df_res = df_res[[col_name]]
|
||||
return df_res
|
||||
|
||||
|
||||
def format_score(forecast_df, col_name="pred", label_shift=5):
|
||||
pred = process_predicted(forecast_df, col_name=col_name)
|
||||
pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
|
||||
pred = pred.dropna()[col_name]
|
||||
return pred
|
||||
|
||||
|
||||
def transform_df(df, col_name="LABEL0"):
|
||||
df_res = df["feature"]
|
||||
df_res[col_name] = df["label"]
|
||||
return df_res
|
||||
|
||||
|
||||
class TFTModel(ModelFT):
|
||||
"""TFT Model"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.model = None
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
return transform_df(df_train), transform_df(df_valid)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
DATASET="Alpha158",
|
||||
MODEL_FOLDER="qlib_alpha158_model",
|
||||
LABEL_COL="LABEL0",
|
||||
LABEL_SHIFT=5,
|
||||
USE_GPU_ID=0,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if DATASET not in ALLOW_DATASET:
|
||||
raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
|
||||
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
|
||||
train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
|
||||
valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
|
||||
|
||||
ExperimentConfig = expt_settings.configs.ExperimentConfig
|
||||
config = ExperimentConfig(DATASET)
|
||||
self.data_formatter = config.make_data_formatter()
|
||||
self.model_folder = MODEL_FOLDER
|
||||
self.gpu_id = USE_GPU_ID
|
||||
self.label_shift = LABEL_SHIFT
|
||||
self.expt_name = DATASET
|
||||
self.label_col = LABEL_COL
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Training Process===========================
|
||||
ModelClass = libs.tft_model.TemporalFusionTransformer
|
||||
if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
|
||||
raise ValueError(
|
||||
"Data formatters should inherit from "
|
||||
+ "GenericDataFormatter! Type={}".format(type(self.data_formatter))
|
||||
)
|
||||
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
if use_gpu[0]:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
|
||||
else:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
|
||||
|
||||
self.data_formatter.set_scalers(train)
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
|
||||
# Wendi: merge the tuned parameters with the non-tuned (fixed) parameters
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
if not os.path.exists(self.model_folder):
|
||||
os.makedirs(self.model_folder)
|
||||
params["model_folder"] = self.model_folder
|
||||
|
||||
print("*** Begin training ***")
|
||||
best_loss = np.Inf
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
self.tf_graph = tf.Graph()
|
||||
with self.tf_graph.as_default():
|
||||
self.sess = tf.Session(config=self.tf_config)
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
self.model = ModelClass(params, use_cudnn=use_gpu[0])
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.model.fit(train_df=train, valid_df=valid)
|
||||
print("*** Finished training ***")
|
||||
saved_model_dir = self.model_folder + "/" + "saved_model"
|
||||
if not os.path.exists(saved_model_dir):
|
||||
os.makedirs(saved_model_dir)
|
||||
self.model.save(saved_model_dir)
|
||||
|
||||
def extract_numerical_data(data):
|
||||
"""Strips out forecast time and identifier columns."""
|
||||
return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
|
||||
|
||||
# p50_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p50_forecast),
|
||||
# 0.5)
|
||||
# p90_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed at {}.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
d_test = dataset.prepare("test", col_set=["feature", "label"])
|
||||
d_test = transform_df(d_test)
|
||||
d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
|
||||
test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Predicting Process===========================
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
print("*** Begin predicting ***")
|
||||
tf.reset_default_graph()
|
||||
|
||||
with self.tf_graph.as_default():
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
output_map = self.model.predict(test, return_targets=True)
|
||||
targets = self.data_formatter.format_predictions(output_map["targets"])
|
||||
p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
|
||||
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
|
||||
predict50 = format_score(p50_forecast, "pred", 1)
|
||||
predict90 = format_score(p90_forecast, "pred", 1)
|
||||
predict = (predict50 + predict90) / 2 # self.label_shift
|
||||
# ===========================Predicting Process===========================
|
||||
return predict
|
||||
|
||||
def finetune(self, dataset: DatasetH):
|
||||
"""
|
||||
finetune model
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
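Both `fit` and `predict` above build their parameters with `params = {**params, **fixed_params}`. A minimal sketch of what that merge does when both dictionaries share a key; the keys and values here are hypothetical stand-ins, since the real ones come from the data formatter:

```python
# Hypothetical values, for illustration only; the real keys come from the data formatter.
default_params = {"learning_rate": 0.01, "hidden_layer_size": 160}
fixed_params = {"total_time_steps": 12, "learning_rate": 0.001}

# Later mappings win: keys from fixed_params override those in default_params.
params = {**default_params, **fixed_params}

print(params)
```

In this form the fixed (non-tuned) values take precedence, so a key present in both dictionaries ends up with the value from `fixed_params`.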
52
examples/benchmarks/TFT/workflow_config_tft.yaml
Normal file
@@ -0,0 +1,52 @@
|
||||
sys:
|
||||
rel_path: .
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TFTModel
|
||||
module_path: tft
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
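The `segments` block in the config above splits the data into train, valid and test date ranges. The following sketch checks that such ranges do not overlap, assuming both ends of each range are inclusive; `find_overlaps` is a hypothetical helper and not part of Qlib:

```python
from datetime import date

# Segment boundaries copied from the config above (both ends assumed inclusive).
segments = {
    "train": (date(2008, 1, 1), date(2014, 12, 31)),
    "valid": (date(2015, 1, 1), date(2016, 12, 31)),
    "test": (date(2017, 1, 1), date(2020, 8, 1)),
}


def find_overlaps(segs):
    """Return pairs of segment names whose inclusive date ranges overlap."""
    names = list(segs)
    overlaps = []
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            (a_start, a_end), (b_start, b_end) = segs[a], segs[b]
            if a_start <= b_end and b_start <= a_end:
                overlaps.append((a, b))
    return overlaps


overlaps = find_overlaps(segments)
print(overlaps)
```

A check like this can catch a validation range that accidentally reaches into the test period before any training is started.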
3
examples/benchmarks/XGBoost/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# XGBoost
|
||||
* Code: [https://github.com/dmlc/xgboost](https://github.com/dmlc/xgboost)
|
||||
* Paper: XGBoost: A Scalable Tree Boosting System. [https://dl.acm.org/doi/pdf/10.1145/2939672.2939785](https://dl.acm.org/doi/pdf/10.1145/2939672.2939785).
|
||||
3
examples/benchmarks/XGBoost/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
xgboost==1.2.1
|
||||
63
examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: XGBModel
|
||||
module_path: qlib.contrib.model.xgboost
|
||||
kwargs:
|
||||
eval_metric: rmse
|
||||
colsample_bytree: 0.8879
|
||||
eta: 0.0421
|
||||
max_depth: 8
|
||||
n_estimators: 647
|
||||
subsample: 0.8789
|
||||
nthread: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
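The workflow configs above describe each component with `class`, `module_path` and `kwargs`. One common way to turn such an entry into an object is sketched below; `init_from_config` is a hypothetical helper written for this illustration, and Qlib's own loader may behave differently:

```python
import importlib


def init_from_config(config):
    """Hypothetical helper: import `module_path`, look up `class`, call it with `kwargs`."""
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    return cls(**config.get("kwargs", {}))


# A standard-library class stands in for a model so the sketch runs without Qlib.
config = {
    "class": "Fraction",
    "module_path": "fractions",
    "kwargs": {"numerator": 3, "denominator": 4},
}
obj = init_from_config(config)
print(obj)
```

Keeping the class name and its arguments in data rather than in code is what lets one YAML file swap `XGBModel` for `TFTModel` without touching the runner.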
@@ -1,222 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import json\n",
|
||||
"import yaml\n",
|
||||
"import pickle\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.utils import exists_qlib_data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"CUR_DIR = Path.cwd()\n",
|
||||
"MARKET = \"csi300\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# use default data\n",
|
||||
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with CUR_DIR.joinpath('estimator_config.yaml').open() as fp:\n",
|
||||
" estimator_name = yaml.load(fp, Loader=yaml.FullLoader)['experiment']['name']\n",
|
||||
"with CUR_DIR.joinpath(estimator_name, 'exp_info.json').open() as fp:\n",
|
||||
" latest_id = json.load(fp)['id']\n",
|
||||
" \n",
|
||||
"estimator_dir = CUR_DIR.joinpath(estimator_name, 'sacred', latest_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# read estimator result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pred_df = pd.read_pickle(estimator_dir.joinpath('pred.pkl'))\n",
|
||||
"report_normal_df = pd.read_pickle(estimator_dir.joinpath('report_normal.pkl'))\n",
|
||||
"report_normal_df.index.names = ['index']\n",
|
||||
"\n",
|
||||
"analysis_df = pd.read_pickle(estimator_dir.joinpath('analysis.pkl'))\n",
|
||||
"positions = pickle.load(estimator_dir.joinpath('positions.pkl').open('rb'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# analyze graphs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.data import D\n",
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## analysis position"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"stock_ret = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
|
||||
"stock_ret.columns = ['label']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### report"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.report_graph(report_normal_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### risk analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## analysis model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_df = D.features(D.instruments(MARKET), ['Ref($close, -2)/Ref($close, -1) - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
|
||||
"label_df.columns = ['label']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### score IC"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
|
||||
"analysis_position.score_ic_graph(pred_label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### model performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_model.model_performance_graph(pred_label)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
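The removed notebook builds its label with `Ref($close, -2)/Ref($close, -1) - 1`. As I read Qlib's expression syntax, a negative offset in `Ref` looks ahead, so this is the return from the next close to the one after it. A plain-Python sketch of that arithmetic under this assumption, with hypothetical prices and a hypothetical function name:

```python
def forward_return_labels(close):
    """label[t] = close[t + 2] / close[t + 1] - 1; None where future prices are missing."""
    labels = []
    for t in range(len(close)):
        if t + 2 < len(close):
            labels.append(close[t + 2] / close[t + 1] - 1)
        else:
            labels.append(None)
    return labels


# Hypothetical closing prices for one instrument on consecutive trading days.
close = [10.0, 10.0, 11.0, 12.1]
labels = forward_return_labels(close)
print(labels)
```

The last two days have no label because the prices they depend on lie beyond the end of the series.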
@@ -1,53 +0,0 @@
|
||||
experiment:
|
||||
name: estimator_example
|
||||
observer_type: file_storage
|
||||
mode: train
|
||||
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
args:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
data:
|
||||
class: Alpha158
|
||||
args:
|
||||
dropna_label: True
|
||||
filter:
|
||||
market: csi300
|
||||
trainer:
|
||||
class: StaticTrainer
|
||||
args:
|
||||
train_start_date: 2008-01-01
|
||||
train_end_date: 2014-12-31
|
||||
validate_start_date: 2015-01-01
|
||||
validate_end_date: 2016-12-31
|
||||
test_start_date: 2017-01-01
|
||||
test_end_date: 2020-08-01
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
args:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
normal_backtest_args:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: SH000300
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
qlib_data:
|
||||
# when testing, please modify the following parameters according to the specific environment
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: "cn"
|
||||
@@ -1,55 +0,0 @@
|
||||
experiment:
|
||||
name: estimator_example
|
||||
observer_type: file_storage
|
||||
mode: train
|
||||
|
||||
model:
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
class: DNNModelPytorch
|
||||
args:
|
||||
loss: mse
|
||||
input_dim: 158
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: 'adam'
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: '0'
|
||||
data:
|
||||
class: Alpha158
|
||||
args:
|
||||
dropna_label: True
|
||||
dropna_feature: True
|
||||
filter:
|
||||
market: csi300
|
||||
trainer:
|
||||
class: StaticTrainer
|
||||
args:
|
||||
train_start_date: 2007-01-01
|
||||
train_end_date: 2014-12-31
|
||||
validate_start_date: 2015-01-01
|
||||
validate_end_date: 2016-12-31
|
||||
test_start_date: 2017-01-01
|
||||
test_end_date: 2020-08-01
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
args:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
normal_backtest_args:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: SH000300
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
|
||||
qlib_data:
|
||||
# when testing, please modify the following parameters according to the specific environment
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: "cn"
|
||||
284
examples/run_all_model.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import fire
|
||||
import time
|
||||
import venv
|
||||
import glob
|
||||
import shutil
|
||||
import signal
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import functools
|
||||
import statistics
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from pprint import pprint
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.cli import workflow
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
# init qlib
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data"
|
||||
exp_folder_name = "run_all_model_records"
|
||||
exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
|
||||
exp_manager = {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + exp_path,
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
}
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
|
||||
if os.path.isdir(exp_path):
|
||||
shutil.rmtree(exp_path)
|
||||
|
||||
# decorator to check the arguments
|
||||
def only_allow_defined_args(function_to_decorate):
|
||||
@functools.wraps(function_to_decorate)
|
||||
def _return_wrapped(*args, **kwargs):
|
||||
"""Internal wrapper function."""
|
||||
argspec = inspect.getfullargspec(function_to_decorate)
|
||||
valid_names = set(argspec.args + argspec.kwonlyargs)
|
||||
if "self" in valid_names:
|
||||
valid_names.remove("self")
|
||||
for arg_name in kwargs:
|
||||
if arg_name not in valid_names:
|
||||
raise ValueError("Unknown argument seen '%s', expected: [%s]" % (arg_name, ", ".join(valid_names)))
|
||||
return function_to_decorate(*args, **kwargs)
|
||||
|
||||
return _return_wrapped
|
||||
|
||||
|
||||
# function to handle Ctrl+Z (SIGTSTP) and Ctrl+C (SIGINT)
|
||||
def handler(signum, frame):
|
||||
os.system("kill -9 %d" % os.getpid())
|
||||
|
||||
|
||||
signal.signal(signal.SIGTSTP, handler)
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
|
||||
# function to calculate the mean and std of a list in the results dictionary
|
||||
def cal_mean_std(results) -> dict:
|
||||
mean_std = dict()
|
||||
for fn in results:
|
||||
mean_std[fn] = dict()
|
||||
for metric in results[fn]:
|
||||
mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0]
|
||||
std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0
|
||||
mean_std[fn][metric] = [mean, std]
|
||||
return mean_std
|
||||
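The aggregation in `cal_mean_std` can be tried on its own. A minimal sketch with hypothetical metric values; `mean_and_std` is an illustrative stand-in for the per-metric step, not the function from this file:

```python
import statistics


def mean_and_std(values):
    """Return [mean, sample std]; a single run has no spread, so its std is reported as 0."""
    if len(values) > 1:
        return [statistics.mean(values), statistics.stdev(values)]
    return [values[0], 0]


# Hypothetical metric values from three runs and from a single run.
three_runs = mean_and_std([1.0, 2.0, 3.0])
one_run = mean_and_std([0.5])
print(three_runs, one_run)
```

An empty list, which would occur if no recorder reached the `FINISHED` state, raises `IndexError` in this sketch, and the indexing in `cal_mean_std` would fail the same way.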
|
||||
|
||||
# function to create a temporary anaconda environment
|
||||
def create_env():
|
||||
# create env
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
env_path = Path(temp_dir).absolute()
|
||||
sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
|
||||
execute(f"conda create --prefix {env_path} python=3.7 -y")
|
||||
python_path = env_path / "bin" / "python" # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# get anaconda activate path
|
||||
conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
|
||||
return env_path, python_path, conda_activate
|
||||
|
||||
|
||||
# function to execute the cmd
|
||||
def execute(cmd):
|
||||
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
|
||||
for line in p.stdout:
|
||||
sys.stdout.write(line.split("\b")[0])
|
||||
if "\b" in line:
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.1)
|
||||
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
|
||||
|
||||
if p.returncode != 0:
|
||||
return p.stderr
|
||||
else:
|
||||
return None
|
||||
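In `execute` only `stdout` is piped, and the `subprocess` documentation states that `Popen.stderr` is `None` unless `stderr=PIPE` was requested, so the value returned on failure carries no error text. One possible alternative is sketched below, assuming streaming of the output is not required; `run_and_capture` is a hypothetical name and not a drop-in replacement for the function above:

```python
import subprocess
import sys


def run_and_capture(args):
    """Run a command; return its stderr text on failure, or None on success."""
    completed = subprocess.run(args, capture_output=True, text=True)
    if completed.returncode != 0:
        return completed.stderr
    return None


# The child processes here are just the current Python interpreter.
ok = run_and_capture([sys.executable, "-c", "print('fine')"])
failed = run_and_capture(
    [sys.executable, "-c", "import sys; sys.stderr.write('boom'); sys.exit(1)"]
)
print(ok, failed)
```

The trade-off is that `subprocess.run` waits for the command to finish before returning anything, whereas the loop in `execute` echoes progress line by line.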
|
||||
|
||||
# function to get all the folders under the benchmarks folder
|
||||
def get_all_folders(models, exclude) -> dict:
|
||||
folders = dict()
|
||||
if isinstance(models, str):
|
||||
model_list = models.split(",")
|
||||
models = [m.lower().strip("[ ]") for m in model_list]
|
||||
elif isinstance(models, list):
|
||||
models = [m.lower() for m in models]
|
||||
elif models is None:
|
||||
models = [f.name.lower() for f in os.scandir("benchmarks")]
|
||||
else:
|
||||
raise ValueError("Input models type is not supported. Please provide str or list without space.")
|
||||
for f in os.scandir("benchmarks"):
|
||||
add = xor(bool(f.name.lower() in models), bool(exclude))
|
||||
if add:
|
||||
path = Path("benchmarks") / f.name
|
||||
folders[f.name] = str(path.resolve())
|
||||
return folders
|
||||
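The `xor` in `get_all_folders` covers both the include and the exclude mode with one expression. A small sketch of that selection rule on hypothetical folder names; `select` is an illustrative helper, not part of the script:

```python
from operator import xor


def select(names, models, exclude):
    """Keep a name when (it is listed) XOR (exclude is set)."""
    wanted = {m.lower() for m in models}
    return [n for n in names if xor(n.lower() in wanted, bool(exclude))]


# Hypothetical folder names, for illustration only.
names = ["LightGBM", "MLP", "TFT"]
included = select(names, ["mlp"], exclude=False)
excluded = select(names, ["mlp"], exclude=True)
print(included, excluded)
```

With `exclude=False` only the listed names are kept, and with `exclude=True` everything except the listed names is kept.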
|
||||
|
||||
# function to get all the files under the model folder
|
||||
def get_all_files(folder_path) -> (str, str):
|
||||
yaml_path = str(Path(f"{folder_path}") / "*.yaml")
|
||||
req_path = str(Path(f"{folder_path}") / "*.txt")
|
||||
return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
|
||||
|
||||
|
||||
# function to retrieve all the results
|
||||
def get_all_results(folders) -> dict:
|
||||
results = dict()
|
||||
for fn in folders:
|
||||
exp = R.get_exp(experiment_name=fn, create=False)
|
||||
recorders = exp.list_recorders()
|
||||
result = dict()
|
||||
result["annualized_return_with_cost"] = list()
|
||||
result["information_ratio_with_cost"] = list()
|
||||
result["max_drawdown_with_cost"] = list()
|
||||
for recorder_id in recorders:
|
||||
if recorders[recorder_id].status == "FINISHED":
|
||||
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
|
||||
metrics = recorder.list_metrics()
|
||||
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
|
||||
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
|
||||
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
|
||||
results[fn] = result
|
||||
return results
|
||||
|
||||
|
||||
# function to generate and save markdown table
|
||||
def gen_and_save_md_table(metrics):
|
||||
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
|
||||
table += "|---|---|---|---|\n"
|
||||
for fn in metrics:
|
||||
ar = metrics[fn]["annualized_return_with_cost"]
|
||||
ir = metrics[fn]["information_ratio_with_cost"]
|
||||
md = metrics[fn]["max_drawdown_with_cost"]
|
||||
table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
|
||||
pprint(table)
|
||||
with open("table.md", "w") as f:
|
||||
f.write(table)
|
||||
return table
|
||||
|
||||
|
||||
# function to run all the models
|
||||
@only_allow_defined_args
|
||||
def run(times=1, models=None, exclude=False):
|
||||
"""
|
||||
Please be aware that this function currently works only on Linux; macOS and Windows will be supported in the future.
|
||||
Any PR that enhances this method is very welcome.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times each model should be run.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in bash:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py 3
|
||||
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py 3 mlp
|
||||
|
||||
# Case 3 - run all models except those given as arguments, multiple times
|
||||
python run_all_model.py 3 [mlp,tft,lstm] True
|
||||
|
||||
# Case 4 - run specific models for one time
|
||||
python run_all_model.py --models=[mlp,lightgbm]
|
||||
|
||||
# Case 5 - run all models except those given as arguments, one time
|
||||
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
|
||||
|
||||
"""
|
||||
# get all folders
|
||||
folders = get_all_folders(models, exclude)
|
||||
# init error messages:
|
||||
errors = dict()
|
||||
# run each model for the requested number of iterations
|
||||
for fn in folders:
|
||||
# create env by anaconda
|
||||
env_path, python_path, conda_activate = create_env()
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn])
|
||||
sys.stderr.write("\n")
|
||||
# install requirements.txt
|
||||
sys.stderr.write("Installing requirements.txt...\n")
|
||||
execute(f"{python_path} -m pip install -r {req_path}")
|
||||
sys.stderr.write("\n")
|
||||
# setup gpu for tft
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
|
||||
)
|
||||
sys.stderr.write("\n")
|
||||
# install qlib
|
||||
sys.stderr.write("Installing qlib...\n")
|
||||
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
|
||||
if fn == "TFT":
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
|
||||
) # TODO: FIX ME!
|
||||
else:
|
||||
execute(
|
||||
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
|
||||
) # TODO: FIX ME!
|
||||
sys.stderr.write("\n")
|
||||
# run workflow_by_config multiple times
|
||||
for i in range(times):
|
||||
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
|
||||
errs = execute(
|
||||
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
|
||||
)
|
||||
if errs is not None:
|
||||
_errs = errors.get(fn, {})
|
||||
_errs.update({i: errs})
|
||||
errors[fn] = _errs
|
||||
sys.stderr.write("\n")
|
||||
# remove env
|
||||
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
|
||||
shutil.rmtree(env_path)
|
||||
# getting all results
|
||||
sys.stderr.write("Retrieving results...\n")
|
||||
results = get_all_results(folders)
|
||||
# calculating the mean and std
|
||||
sys.stderr.write("Calculating the mean and std of results...\n")
|
||||
results = cal_mean_std(results)
|
||||
# generating md table
|
||||
sys.stderr.write("Generating markdown table...\n")
|
||||
gen_and_save_md_table(results)
|
||||
sys.stderr.write("\n")
|
||||
# print errors
|
||||
sys.stderr.write("Here are some of the errors of the models...\n")
|
||||
pprint(errors)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(run)  # run all the models
|
||||
@@ -1,121 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.estimator.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "CSI300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"dropna_label": True,
|
||||
"start_date": "2008-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"market": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2008-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
}
|
||||
|
||||
# use default DataHandler
|
||||
# custom DataHandler, refer to: TODO: DataHandler API url
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
|
||||
**TRAINER_CONFIG
|
||||
)
|
||||
|
||||
MODEL_CONFIG = {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
}
|
||||
# use default model
|
||||
# custom Model, refer to: TODO: Model API url
|
||||
model = LGBModel(**MODEL_CONFIG)
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
_pred = model.predict(x_test)
|
||||
_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
|
||||
|
||||
# backtest requires pred_score
|
||||
pred_score = pd.DataFrame(index=_pred.index)
|
||||
pred_score["score"] = _pred.iloc(axis=1)[0]
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If you need a more detailed analysis, refer to: examples/train_and_backtest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,338 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.estimator.handler import Alpha158\n",
|
||||
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
" risk_analysis,\n",
|
||||
")\n",
|
||||
"from qlib.utils import exists_qlib_data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# use default data\n",
|
||||
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"MARKET = \"csi300\"\n",
|
||||
"BENCHMARK = \"SH000300\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# train model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# train model\n",
|
||||
"###################################\n",
|
||||
"DATA_HANDLER_CONFIG = {\n",
|
||||
" \"dropna_label\": True,\n",
|
||||
" \"start_date\": \"2008-01-01\",\n",
|
||||
" \"end_date\": \"2020-08-01\",\n",
|
||||
" \"market\": MARKET,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"TRAINER_CONFIG = {\n",
|
||||
" \"train_start_date\": \"2008-01-01\",\n",
|
||||
" \"train_end_date\": \"2014-12-31\",\n",
|
||||
" \"validate_start_date\": \"2015-01-01\",\n",
|
||||
" \"validate_end_date\": \"2016-12-31\",\n",
|
||||
" \"test_start_date\": \"2017-01-01\",\n",
|
||||
" \"test_end_date\": \"2020-08-01\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# use default DataHandler\n",
|
||||
"# custom DataHandler, refer to: TODO: DataHandler api url\n",
|
||||
"x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(**TRAINER_CONFIG)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"MODEL_CONFIG = {\n",
|
||||
" \"loss\": \"mse\",\n",
|
||||
" \"colsample_bytree\": 0.8879,\n",
|
||||
" \"learning_rate\": 0.0421,\n",
|
||||
" \"subsample\": 0.8789,\n",
|
||||
" \"lambda_l1\": 205.6999,\n",
|
||||
" \"lambda_l2\": 580.9768,\n",
|
||||
" \"max_depth\": 8,\n",
|
||||
" \"num_leaves\": 210,\n",
|
||||
" \"num_threads\": 20,\n",
|
||||
"}\n",
|
||||
"# use default model\n",
|
||||
"# custom Model, refer to: TODO: Model api url\n",
|
||||
"model = LGBModel(**MODEL_CONFIG)\n",
|
||||
"model.fit(x_train, y_train, x_validate, y_validate)\n",
|
||||
"_pred = model.predict(x_test)\n",
|
||||
"_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)\n",
|
||||
"\n",
|
||||
"# backtest requires pred_score\n",
|
||||
"pred_score = pd.DataFrame(index=_pred.index)\n",
|
||||
"pred_score[\"score\"] = _pred.iloc(axis=1)[0]\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# backtest"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# backtest\n",
|
||||
"###################################\n",
|
||||
"STRATEGY_CONFIG = {\n",
|
||||
" \"topk\": 50,\n",
|
||||
" \"n_drop\": 5}\n",
|
||||
"BACKTEST_CONFIG = {\n",
|
||||
" \"verbose\": False,\n",
|
||||
" \"limit_threshold\": 0.095,\n",
|
||||
" \"account\": 100000000,\n",
|
||||
" \"benchmark\": BENCHMARK,\n",
|
||||
" \"deal_price\": \"close\",\n",
|
||||
" \"open_cost\": 0.0005,\n",
|
||||
" \"close_cost\": 0.0015,\n",
|
||||
" \"min_cost\": 5,\n",
|
||||
" \n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# use default strategy\n",
|
||||
"# custom Strategy, refer to: TODO: Strategy api url\n",
|
||||
"strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)\n",
|
||||
"report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# analyze"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# analyze\n",
|
||||
"# If you need a more detailed analysis, refer to: examples/train_and_backtest.ipynb\n",
|
||||
"###################################\n",
|
||||
"analysis = dict()\n",
|
||||
"analysis[\"excess_return_without_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n",
|
||||
"analysis[\"excess_return_with_cost\"] = risk_analysis(\n",
|
||||
" report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"]\n",
|
||||
")\n",
|
||||
"analysis_df = pd.concat(analysis) # type: pd.DataFrame\n",
|
||||
"print(analysis_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# analyze graphs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"pred_df_dates = pred_score.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = report_normal\n",
|
||||
"positions = positions_normal\n",
|
||||
"pred_df = pred_score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## analysis position"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"stock_ret = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
|
||||
"stock_ret.columns = ['label']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### report"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.report_graph(report_normal_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### risk analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## analysis model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_df = D.features(D.instruments(MARKET), ['Ref($close, -2)/Ref($close, -1) - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
|
||||
"label_df.columns = ['label']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### score IC"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
|
||||
"analysis_position.score_ic_graph(pred_label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### model performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_model.model_performance_graph(pred_label)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
380
examples/workflow_by_code.ipynb
Normal file
@@ -0,0 +1,380 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright (c) Microsoft Corporation.\n",
|
||||
"# Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
"\n",
|
||||
"scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
|
||||
"if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
|
||||
" # download get_data.py script\n",
|
||||
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
|
||||
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||
" import requests\n",
|
||||
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
|
||||
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
|
||||
" fp.write(resp.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
" risk_analysis,\n",
|
||||
")\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.utils import flatten_dict\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# use default data\n",
|
||||
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(scripts_dir))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"market = \"csi300\"\n",
|
||||
"benchmark = \"SH000300\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# train model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# train model\n",
|
||||
"###################################\n",
|
||||
"data_handler_config = {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": market,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"task = {\n",
|
||||
" \"model\": {\n",
|
||||
" \"class\": \"LGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"loss\": \"mse\",\n",
|
||||
" \"colsample_bytree\": 0.8879,\n",
|
||||
" \"learning_rate\": 0.0421,\n",
|
||||
" \"subsample\": 0.8789,\n",
|
||||
" \"lambda_l1\": 205.6999,\n",
|
||||
" \"lambda_l2\": 580.9768,\n",
|
||||
" \"max_depth\": 8,\n",
|
||||
" \"num_leaves\": 210,\n",
|
||||
" \"num_threads\": 20,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"dataset\": {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
" \"module_path\": \"qlib.contrib.data.handler\",\n",
|
||||
" \"kwargs\": data_handler_config,\n",
|
||||
" },\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initialization\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
"# start exp to train model\n",
|
||||
"with R.start(experiment_name=\"train_model\"):\n",
|
||||
" R.log_params(**flatten_dict(task))\n",
|
||||
" model.fit(dataset)\n",
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
" rid = R.get_recorder().id\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# prediction, backtest & analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# prediction, backtest & analysis\n",
|
||||
"###################################\n",
|
||||
"port_analysis_config = {\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"class\": \"TopkDropoutStrategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.strategy\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"topk\": 50,\n",
|
||||
" \"n_drop\": 5,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"backtest\": {\n",
|
||||
" \"verbose\": False,\n",
|
||||
" \"limit_threshold\": 0.095,\n",
|
||||
" \"account\": 100000000,\n",
|
||||
" \"benchmark\": benchmark,\n",
|
||||
" \"deal_price\": \"close\",\n",
|
||||
" \"open_cost\": 0.0005,\n",
|
||||
" \"close_cost\": 0.0015,\n",
|
||||
" \"min_cost\": 5,\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=\"backtest_analysis\"):\n",
|
||||
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
|
||||
" model = recorder.load_object(\"trained_model\")\n",
|
||||
"\n",
|
||||
" # prediction\n",
|
||||
" recorder = R.get_recorder()\n",
|
||||
" ba_rid = recorder.id\n",
|
||||
" sr = SignalRecord(model, dataset, recorder)\n",
|
||||
" sr.generate()\n",
|
||||
"\n",
|
||||
" # backtest & analysis\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config)\n",
|
||||
" par.generate()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# analyze graphs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
|
||||
"positions = recorder.load_object(\"portfolio_analysis/positions_normal.pkl\")\n",
|
||||
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis.pkl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## analysis position"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### report"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.report_graph(report_normal_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### risk analysis"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## analysis model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
|
||||
"label_df.columns = ['label']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### score IC"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
|
||||
"analysis_position.score_ic_graph(pred_label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### model performance"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"analysis_model.model_performance_graph(pred_label)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
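The `class` / `module_path` / `kwargs` dictionaries in the notebook above are passed to `init_instance_by_config`. As a rough illustration of how such a config can be turned into an object, here is a minimal sketch. The helper name `build_from_config` is hypothetical, and this is not Qlib's implementation, which may support more input forms. A standard-library class stands in for a model or dataset class.

```python
import importlib


def build_from_config(config):
    """Instantiate an object from a {"class", "module_path", "kwargs"} dict.

    Hypothetical helper for illustration only; not Qlib's implementation.
    """
    # Import the module named in the config, then look up the class by name.
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    # Forward the keyword arguments (if any) to the constructor.
    return cls(**config.get("kwargs", {}))


# A standard-library class stands in for a model or dataset class.
example_config = {
    "class": "Fraction",
    "module_path": "fractions",
    "kwargs": {"numerator": 3, "denominator": 4},
}
instance = build_from_config(example_config)
print(instance)
```

Keeping the class path and arguments in plain data is what allows the same task description to live in Python code or in a YAML file.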
|
||||
120
examples/workflow_by_code.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.strategy",
|
||||
"kwargs": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
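The script logs its nested `task` dictionary through `R.log_params(**flatten_dict(task))`, which implies the nested config is reduced to a single level of key/value pairs first. The sketch below shows one way such flattening can work. The function name `flatten` and the `.` separator are assumptions made for illustration; Qlib's `flatten_dict` may use a different separator or handle edge cases differently.

```python
def flatten(nested, parent_key="", sep="."):
    """Flatten nested dicts into one level, joining keys with `sep`.

    Hypothetical helper for illustration; Qlib's `flatten_dict` may differ.
    """
    flat = {}
    for key, value in nested.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, carrying the key prefix along.
            flat.update(flatten(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


task_fragment = {
    "model": {"class": "LGBModel", "kwargs": {"loss": "mse", "max_depth": 8}},
}
flat_params = flatten(task_fragment)
print(flat_params)
```

A flat mapping is convenient for experiment trackers, which typically store parameters as simple name/value pairs.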
|
||||
@@ -2,19 +2,20 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
__version__ = "0.5.1.dev0"
|
||||
__version__ = "0.6.0.alpha"
|
||||
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import platform
|
||||
import sys
|
||||
import copy
|
||||
import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from .utils import can_use_cache
|
||||
|
||||
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
|
||||
from .workflow.utils import experiment_exit_handler
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
@@ -22,6 +23,7 @@ def init(default_conf="client", **kwargs):
|
||||
from .data.data import register_all_wrappers
|
||||
from .log import get_module_logger, set_log_with_config
|
||||
from .data.cache import H
|
||||
from .workflow import R, QlibRecorder
|
||||
|
||||
C.reset()
|
||||
H.clear()
|
||||
@@ -34,17 +36,18 @@ def init(default_conf="client", **kwargs):
|
||||
if _logging_config:
|
||||
set_log_with_config(_logging_config)
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
LOG = get_module_logger("Initialization", level=logging.INFO)
|
||||
LOG.info(f"default_conf: {default_conf}.")
|
||||
|
||||
C.set_mode(default_conf)
|
||||
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
|
||||
|
||||
for k, v in kwargs.items():
|
||||
C[k] = v
|
||||
if k not in C:
|
||||
LOG.warning("Unrecognized config %s" % k)
|
||||
|
||||
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
|
||||
C.resolve_path()
|
||||
|
||||
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
|
||||
@@ -61,12 +64,10 @@ def init(default_conf="client", **kwargs):
|
||||
if not os.path.exists(C["provider_uri"]):
|
||||
if C["auto_mount"]:
|
||||
LOG.error(
|
||||
"Invalid provider uri: {}, please check if a valid provider uri has been set. This path does not exist.".format(
|
||||
C["provider_uri"]
|
||||
)
|
||||
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
|
||||
)
|
||||
else:
|
||||
LOG.warning("auto_path is False, please make sure {} is mounted".format(C["mount_path"]))
|
||||
LOG.warning(f"auto_mount is False, please make sure {C['mount_path']} is mounted")
|
||||
elif C.get_uri_type() == QlibConfig.NFS_URI:
|
||||
_mount_nfs_uri(C)
|
||||
else:
|
||||
@@ -80,6 +81,13 @@ def init(default_conf="client", **kwargs):
|
||||
if "flask_server" in C:
|
||||
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
|
||||
|
||||
# set up QlibRecorder
|
||||
exp_manager = init_instance_by_config(C["exp_manager"])
|
||||
qr = QlibRecorder(exp_manager)
|
||||
R.register(qr)
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
from .log import get_module_logger
|
||||
@@ -94,9 +102,7 @@ def _mount_nfs_uri(C):
|
||||
if not C["auto_mount"]:
|
||||
if not os.path.exists(C["mount_path"]):
|
||||
raise FileNotFoundError(
|
||||
"Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format(
|
||||
C["mount_path"], mount_command
|
||||
)
|
||||
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
)
|
||||
else:
|
||||
# Judging system type
|
||||
@@ -153,9 +159,7 @@ def _mount_nfs_uri(C):
|
||||
os.makedirs(C["mount_path"], exist_ok=True)
|
||||
except Exception:
|
||||
raise OSError(
|
||||
"Failed to create directory {}, please create {} manually!".format(
|
||||
C["mount_path"], C["mount_path"]
|
||||
)
|
||||
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
|
||||
)
|
||||
|
||||
# check nfs-common
|
||||
@@ -167,20 +171,18 @@ def _mount_nfs_uri(C):
|
||||
command_status = os.system(mount_command)
|
||||
if command_status == 256:
|
||||
raise OSError(
|
||||
"mount {} on {} error! Needs SUDO! Please mount manually: {}".format(
|
||||
C["provider_uri"], C["mount_path"], mount_command
|
||||
)
|
||||
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
|
||||
)
|
||||
elif command_status == 32512:
|
||||
# LOG.error("Command error")
|
||||
raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"]))
|
||||
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
|
||||
elif command_status == 0:
|
||||
LOG.info("Mount finished")
|
||||
else:
|
||||
LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path))
|
||||
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
|
||||
|
||||
|
||||
def init_from_yaml_conf(conf_path):
|
||||
def init_from_yaml_conf(conf_path, **kwargs):
|
||||
"""init_from_yaml_conf
|
||||
|
||||
:param conf_path: A path to the qlib config in yml format
|
||||
@@ -188,5 +190,6 @@ def init_from_yaml_conf(conf_path):
|
||||
|
||||
with open(conf_path) as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
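With the added `config.update(kwargs)` line, keyword arguments passed to `init_from_yaml_conf` take precedence over values read from the YAML file, and `default_conf` is split off before the rest is forwarded to `init`. The following sketch reproduces only that merge step on plain dictionaries. The function name `merge_config` is hypothetical, and a literal dict stands in for the result of `yaml.load`.

```python
def merge_config(file_config, **overrides):
    """Return (default_conf, settings), with overrides taking precedence."""
    config = dict(file_config)  # copy so the caller's dict is not mutated
    config.update(overrides)  # keyword arguments win over file values
    default_conf = config.pop("default_conf", "client")
    return default_conf, config


# Stand-in for the dictionary that yaml.load would return.
loaded = {"provider_uri": "~/.qlib/qlib_data/cn_data", "region": "cn"}
mode, settings = merge_config(loaded, region="us")
print(mode, settings)
```

Because `default_conf` is popped after the update, a caller can also switch the mode through a keyword argument, for example `merge_config(loaded, default_conf="server")`.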
|
||||
|
||||
@@ -14,6 +14,8 @@ Two modes are supported
|
||||
import copy
|
||||
from pathlib import Path
|
||||
import re
|
||||
import os
|
||||
import multiprocessing
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -62,6 +64,8 @@ class Config:
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
_default_config = {
|
||||
# data provider config
|
||||
"calendar_provider": "LocalCalendarProvider",
|
||||
@@ -78,7 +82,7 @@ _default_config = {
|
||||
"calendar_cache": None,
|
||||
# for simple dataset cache
|
||||
"local_cache_path": None,
|
||||
"kernels": 16,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
|
||||
"maxtasksperchild": None,
|
||||
"default_disk_cache": 1, # 0:skip/1:use
|
||||
@@ -124,6 +128,15 @@ _default_config = {
|
||||
},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
},
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
"kwargs": {
|
||||
"uri": "file:" + str(Path(os.getcwd()).resolve() / "mlruns"),
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
@@ -141,10 +154,11 @@ MODE_CONF = {
|
||||
"redis_host": "127.0.0.1",
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
"kernels": 64,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# cache
|
||||
"expression_cache": "DiskExpressionCache",
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"mount_path": None,
|
||||
},
|
||||
"client": {
|
||||
# data provider config
|
||||
@@ -162,7 +176,7 @@ MODE_CONF = {
|
||||
"dataset_cache": "DiskDatasetCache",
|
||||
"calendar_cache": None,
|
||||
# client config
|
||||
"kernels": 16,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
"mount_path": None,
|
||||
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
@@ -212,7 +226,9 @@ class QlibConfig(Config):
|
||||
|
||||
def get_uri_type(self):
|
||||
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
|
||||
is_nfs_or_win = re.match("^[^/]+:.+", self["provider_uri"]) is not None # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
is_nfs_or_win = (
|
||||
re.match("^[^/]+:.+", self["provider_uri"]) is not None
|
||||
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
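The two regular expressions in `get_uri_type` separate Windows drive paths from NFS-style `host:/path` URIs. The standalone sketch below copies both patterns so their behaviour can be checked in isolation. The function name `classify_uri`, the constant values, and the fallback to a local URI are assumptions made for illustration, since the rest of the method is not shown in this hunk.

```python
import re

# Assumed labels for illustration; the real constants live on QlibConfig.
LOCAL_URI = "local"
NFS_URI = "nfs"


def classify_uri(provider_uri):
    """Classify a provider URI as NFS-style or local."""
    # Windows drive paths such as 'C:\\data' or 'D:'.
    is_win = re.match("^[a-zA-Z]:.*", provider_uri) is not None
    # Anything of the form '<no slash>:<something>', e.g. 'host:/data/'.
    is_nfs_or_win = re.match("^[^/]+:.+", provider_uri) is not None
    if is_nfs_or_win and not is_win:
        return NFS_URI
    return LOCAL_URI  # assumed fallback; not shown in the diff


for uri in ["C:\\data", "D:", "host:/data/", "~/.qlib/qlib_data/cn_data"]:
    print(uri, "->", classify_uri(uri))
```

One consequence of these patterns is that a single-letter hostname such as `h:/data` is treated as a Windows drive path rather than an NFS URI.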
|
||||
|
||||
@@ -10,6 +10,7 @@ from ...data import D
|
||||
from .account import Account
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...data.dataset.utils import get_level_index
|
||||
|
||||
LOG = get_module_logger("backtest")
|
||||
|
||||
@@ -18,7 +19,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
"""Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <instrument, datetime> index and one `score` column
|
||||
the prediction should have a <datetime, instrument> index and one `score` column
|
||||
Qlib wants to support multi-signal strategies in the future, so pd.Series is not used.
|
||||
strategy : Strategy()
|
||||
strategy part for backtest
|
||||
trade_exchange : Exchange()
|
||||
@@ -43,6 +45,12 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
`benchmark` is str, will use the daily change as the 'bench'.
|
||||
benchmark code, default is SH000905 CSI500
|
||||
"""
|
||||
# Convert format if the input format is not expected
|
||||
if get_level_index(pred, level="datetime") == 1:
|
||||
pred = pred.swaplevel().sort_index()
|
||||
if isinstance(pred, pd.Series):
|
||||
pred = pred.to_frame("score")
|
||||
|
||||
trade_account = Account(init_cash=account)
|
||||
_pred_dates = pred.index.get_level_values(level="datetime")
|
||||
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
|
||||
@@ -57,6 +65,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
get_date_by_shift(predict_dates[-1], shift=shift),
|
||||
disk_cache=1,
|
||||
)
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
|
||||
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
|
||||
@@ -71,7 +81,7 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
|
||||
# 1. Load the score_series at pred_date
|
||||
try:
|
||||
score = pred.loc(axis=0)[:, pred_date] # (stock_id, trade_date) multi_index, score in pdate
|
||||
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
|
||||
score_series = score.reset_index(level="datetime", drop=True)[
|
||||
"score"
|
||||
] # pd.Series(index:stock_id, data: score)
|
||||
|
||||
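Reviewer note: the index handling introduced in the `backtest` hunks above can be tried on a toy frame. This is a sketch under assumptions: `get_level_index` is a qlib helper that is not shown in this diff, so the position of the `datetime` level is looked up through `index.names` instead, and the instrument codes and scores are made up.

```python
import pandas as pd


def normalize_pred(pred):
    """Return a DataFrame indexed by <datetime, instrument> with a `score` column."""
    if isinstance(pred, pd.Series):
        pred = pred.to_frame("score")
    # Stand-in for get_level_index(pred, level="datetime") == 1.
    if list(pred.index.names).index("datetime") == 1:
        pred = pred.swaplevel().sort_index()
    return pred


# Toy input in the old <instrument, datetime> layout.
idx = pd.MultiIndex.from_product(
    [["SH600000", "SH600001"], pd.to_datetime(["2020-01-02", "2020-01-03"])],
    names=["instrument", "datetime"],
)
raw = pd.Series([0.1, 0.2, 0.3, 0.4], index=idx)

pred = normalize_pred(raw)
pred_date = pd.Timestamp("2020-01-02")

# Same selection as in the diff: all instruments on one prediction date.
score = pred.loc(axis=0)[pred_date, :]
score_series = score.reset_index(level="datetime", drop=True)["score"]
print(score_series)
```

Sorting after `swaplevel()` keeps the index lexsorted, which label-based slicing on the first level relies on.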
@@ -166,7 +166,7 @@ class Position:
|
||||
def save_position(self, path, last_trade_date):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series()
|
||||
cash = pd.Series(dtype=float)  # `np.float` was a deprecated alias of `float` and is removed in NumPy >= 1.24
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["today_account_value"] = p["today_account_value"]
|
||||
|
||||
429
qlib/contrib/data/handler.py
Normal file
@@ -0,0 +1,429 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_cls_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
import copy
|
||||
|
||||
|
||||
def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
args = getfullargspec(klass).args
|
||||
if "fit_start_time" in args and "fit_end_time" in args:
|
||||
assert (
|
||||
fit_start_time is not None and fit_end_time is not None
|
||||
), "Make sure `fit_start_time` and `fit_end_time` are not None."
|
||||
pkwargs.update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
|
||||
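Reviewer note: `check_transform_proc` injects the fit window only into processors whose constructor accepts it. The sketch below reproduces that idea without qlib. The two processor classes and the `REGISTRY` lookup are hypothetical and replace `get_cls_kwargs(p, processor_module)`. A `ValueError` is raised where the original uses `assert`.

```python
from inspect import getfullargspec


class MinMaxNorm:  # hypothetical processor that needs a fit window
    def __init__(self, fit_start_time, fit_end_time, fields_group=None):
        self.fit_start_time = fit_start_time
        self.fit_end_time = fit_end_time
        self.fields_group = fields_group


class Fillna:  # hypothetical processor without a fit window
    def __init__(self, fill_value=0):
        self.fill_value = fill_value


REGISTRY = {"MinMaxNorm": MinMaxNorm, "Fillna": Fillna}


def inject_fit_window(proc_l, fit_start_time, fit_end_time):
    new_l = []
    for p in proc_l:
        if isinstance(p, dict):
            klass = REGISTRY[p["class"]]
            pkwargs = dict(p.get("kwargs", {}))
            # Constructor argument names, including `self`.
            args = getfullargspec(klass).args
            if "fit_start_time" in args and "fit_end_time" in args:
                if fit_start_time is None or fit_end_time is None:
                    raise ValueError("`fit_start_time` and `fit_end_time` must not be None")
                pkwargs.update({"fit_start_time": fit_start_time, "fit_end_time": fit_end_time})
            new_l.append({"class": klass.__name__, "kwargs": pkwargs})
        else:
            # Already-instantiated processors are passed through unchanged.
            new_l.append(p)
    return new_l


out = inject_fit_window(
    [{"class": "MinMaxNorm", "kwargs": {}}, {"class": "Fillna"}],
    "2008-01-01",
    "2014-12-31",
)
print(out)
```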
class ALPHA360_Denoise(DataHandlerLP):
|
||||
def __init__(self, instruments="csi500", start_time=None, end_time=None, fit_start_time=None, fit_end_time=None):
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"label": self.get_label_config(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
learn_processors = [
|
||||
{"class": "DropnaLabel", "kwargs": {"fields_group": "label"}},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
]
|
||||
infer_processors = [
|
||||
{"class": "ProcessInf", "kwargs": {}},
|
||||
{"class": "TanhProcess", "kwargs": {}},
|
||||
{"class": "Fillna", "kwargs": {}},
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
)
|
||||
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
def get_feature_config(self):
|
||||
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % (i)]
|
||||
names += ["CLOSE%d" % (i)]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % (i)]
|
||||
names += ["OPEN%d" % (i)]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % (i)]
|
||||
names += ["HIGH%d" % (i)]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % (i)]
|
||||
names += ["LOW%d" % (i)]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % (i)]
|
||||
names += ["VWAP%d" % (i)]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/$volume" % (i)]
|
||||
names += ["VOLUME%d" % (i)]
|
||||
fields += ["$volume/$volume"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
_DEFAULT_LEARN_PROCESSORS = [
|
||||
{"class": "DropnaLabel"},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
]
|
||||
_DEFAULT_INFER_PROCESSORS = [
|
||||
{"class": "ProcessInf", "kwargs": {}},
|
||||
{"class": "ZScoreNorm", "kwargs": {}},
|
||||
{"class": "Fillna", "kwargs": {}},
|
||||
]
|
||||
|
||||
|
||||
class ALPHA360(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi500",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=_DEFAULT_INFER_PROCESSORS,
|
||||
learn_processors=_DEFAULT_LEARN_PROCESSORS,
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"label": kwargs.get("label", self.get_label_config()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
super().__init__(
|
||||
instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
data_loader=data_loader,
|
||||
learn_processors=learn_processors,
|
||||
infer_processors=infer_processors,
|
||||
)
|
||||
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
def get_feature_config(self):
|
||||
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($close, %d)/$close" % (i)]
|
||||
names += ["CLOSE%d" % (i)]
|
||||
fields += ["$close/$close"]
|
||||
names += ["CLOSE0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($open, %d)/$close" % (i)]
|
||||
names += ["OPEN%d" % (i)]
|
||||
fields += ["$open/$close"]
|
||||
names += ["OPEN0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($high, %d)/$close" % (i)]
|
||||
names += ["HIGH%d" % (i)]
|
||||
fields += ["$high/$close"]
|
||||
names += ["HIGH0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($low, %d)/$close" % (i)]
|
||||
names += ["LOW%d" % (i)]
|
||||
fields += ["$low/$close"]
|
||||
names += ["LOW0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($vwap, %d)/$close" % (i)]
|
||||
names += ["VWAP%d" % (i)]
|
||||
fields += ["$vwap/$close"]
|
||||
names += ["VWAP0"]
|
||||
for i in range(59, 0, -1):
|
||||
fields += ["Ref($volume, %d)/$volume" % (i)]
|
||||
names += ["VOLUME%d" % (i)]
|
||||
fields += ["$volume/$volume"]
|
||||
names += ["VOLUME0"]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
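Reviewer note: the six unrolled loops in `get_feature_config` each emit 59 lagged expressions plus one same-day expression, which gives 6 x 60 = 360 features. The compact sketch below generates the same strings in the same order. The function name `alpha360_fields` is made up for illustration and is not part of the diff.

```python
def alpha360_fields():
    fields, names = [], []
    # Price columns are normalised by the current close.
    for col in ["close", "open", "high", "low", "vwap"]:
        for i in range(59, 0, -1):
            fields.append("Ref($%s, %d)/$close" % (col, i))
            names.append("%s%d" % (col.upper(), i))
        fields.append("$%s/$close" % col)
        names.append("%s0" % col.upper())
    # Volume is normalised by the current volume instead.
    for i in range(59, 0, -1):
        fields.append("Ref($volume, %d)/$volume" % i)
        names.append("VOLUME%d" % i)
    fields.append("$volume/$volume")
    names.append("VOLUME0")
    return fields, names


fields, names = alpha360_fields()
print(len(fields), names[:3], names[-1])
```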
class ALPHA360vwap(ALPHA360):
|
||||
def get_label_config(self):
|
||||
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
|
||||
|
||||
|
||||
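Reviewer note: the label expression `Ref($close, -2)/Ref($close, -1) - 1` can be read as the return from day t+1 to day t+2. The pandas sketch below assumes that `Ref` with a negative offset looks ahead, which is modelled here as `shift(-k)`. The prices are made up.

```python
import pandas as pd


def forward_return_label(close: pd.Series) -> pd.Series:
    # Ref($close, -k) is modelled as the value k rows ahead.
    return close.shift(-2) / close.shift(-1) - 1


close = pd.Series(
    [10.0, 11.0, 12.1, 12.1],
    index=pd.date_range("2020-01-01", periods=4),
)
label = forward_return_label(close)
print(label)
```

The last two rows have no label because the required future prices do not exist, which is presumably why the handlers above apply `DropnaLabel` to the learning data.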
class Alpha158(DataHandlerLP):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi500",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=_DEFAULT_LEARN_PROCESSORS,
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())},
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
instruments,
|
||||
start_time,
|
||||
end_time,
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
conf = {
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
return self.parse_config_to_fields(conf)
|
||||
|
||||
def get_label_config(self):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
@staticmethod
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
# if include is None, the default operators are used
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in the dataset config lists the operators to skip
|
||||
# `include` in the dataset config lists the operators to keep (None means all)
|
||||
use = lambda x: x not in exclude and (include is None or x in include)
|
||||
if use("ROC"):
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
||||
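Reviewer note: the sketch below isolates two pieces of `parse_config_to_fields`, the `include`/`exclude` gate and the raw-price expressions, and checks the arithmetic behind the name "Alpha158". `ROLLING_GATES` is a hand-copied list of the keys passed to `use(...)` in the diff above and is an assumption of this note, not an object in the code.

```python
def make_selector(include=None, exclude=()):
    # Same rule as the `use` lambda: not excluded, and included if a whitelist is given.
    def use(name):
        return name not in exclude and (include is None or name in include)
    return use


def price_fields(windows, features):
    fields, names = [], []
    for field in features:
        field = field.lower()
        fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
        names += [field.upper() + str(d) for d in windows]
    return fields, names


# Keys passed to use(...) in the diff (29 operators).
ROLLING_GATES = [
    "ROC", "MA", "STD", "BETA", "RSQR", "RESI", "MAX", "LOW", "QTLU", "QTLD",
    "RANK", "RSV", "IMAX", "IMIN", "IMXD", "CORR", "CORD", "CNTP", "CNTN", "CNTD",
    "SUMP", "SUMN", "SUMD", "VMA", "VSTD", "WVMA", "VSUMP", "VSUMN", "VSUMD",
]


def count_features(include=None, exclude=(), n_windows=5):
    use = make_selector(include, exclude)
    n_kbar = 9                                       # KMID ... KSFT2
    n_price = len(price_fields([0], ["OPEN", "HIGH", "LOW", "VWAP"])[0])
    n_rolling = sum(use(op) for op in ROLLING_GATES) * n_windows
    return n_kbar + n_price + n_rolling


print(price_fields([0, 1], ["OPEN"]))
print(count_features())                  # default config of Alpha158.get_feature_config
print(count_features(exclude=["RANK"]))
```

With the default config (`kbar`, four price fields at window 0, and all rolling operators over five windows) the count is 9 + 4 + 29 x 5 = 158.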
class Alpha158vwap(Alpha158):
|
||||
def get_label_config(self):
|
||||
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
|
||||
118
qlib/contrib/data/processor.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
from ...log import TimeInspector
|
||||
from ...utils.serial import Serializable
|
||||
from ...data.dataset.processor import Processor, get_group_columns
|
||||
|
||||
|
||||
class ConfigSectionProcessor(Processor):
|
||||
"""
|
||||
This processor is designed for Alpha158 and will be replaced by simple processors in the future.
|
||||
"""
|
||||
|
||||
def __init__(self, fields_group=None, **kwargs):
|
||||
super().__init__()
|
||||
# Options
|
||||
self.fillna_feature = kwargs.get("fillna_feature", True)
|
||||
self.fillna_label = kwargs.get("fillna_label", True)
|
||||
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
|
||||
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
|
||||
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
|
||||
|
||||
self.fields_group = None
|
||||
|
||||
def __call__(self, df):
|
||||
return self._transform(df)
|
||||
|
||||
def _transform(self, df):
|
||||
def _label_norm(x):
|
||||
x = x - x.mean() # copy
|
||||
x /= x.std()
|
||||
if self.clip_label_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.fillna_label:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
def _feature_norm(x):
|
||||
x = x - x.median() # copy
|
||||
x /= x.abs().median() * 1.4826
|
||||
if self.clip_feature_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.shrink_feature_outlier:
|
||||
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
|
||||
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
|
||||
if self.fillna_feature:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
# Copy the focus part and change it to single level
|
||||
selected_cols = get_group_columns(df, self.fields_group)
|
||||
df_focus = df[selected_cols].copy()
|
||||
if len(df_focus.columns.levels) > 1:
|
||||
df_focus = df_focus.droplevel(level=0)
|
||||
|
||||
# Label
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
"KSFT",
|
||||
"OPEN",
|
||||
"HIGH",
|
||||
"LOW",
|
||||
"CLOSE",
|
||||
"VWAP",
|
||||
"ROC",
|
||||
"MA",
|
||||
"BETA",
|
||||
"RESI",
|
||||
"QTLU",
|
||||
"QTLD",
|
||||
"RSV",
|
||||
"SUMP",
|
||||
"SUMN",
|
||||
"SUMD",
|
||||
"VSUMP",
|
||||
"VSUMN",
|
||||
"VSUMD",
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
|
||||
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
|
||||
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
df[selected_cols] = df_focus.values
|
||||
|
||||
TimeInspector.log_cost_time("Finished preprocessing data.")
|
||||
|
||||
return df
|
||||
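Reviewer note: `_feature_norm` is a robust z-score. It centres by the median, scales by the median absolute deviation times 1.4826, and then either clips or shrinks values beyond +/-3. The sketch below applies the same steps to a single toy Series. It returns new objects instead of using `inplace=True`, which is a deliberate deviation from the code in the diff.

```python
import pandas as pd


def robust_zscore(x: pd.Series) -> pd.Series:
    x = x - x.median()
    # After centring, the median of |x| is the median absolute deviation (MAD).
    return x / (x.abs().median() * 1.4826)


def shrink_outliers(x: pd.Series) -> pd.Series:
    # Values above 3 are squeezed into (3, 3.5]; values below -3 into [-3.5, -3).
    x = x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5)
    x = x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5)
    return x


s = pd.Series([1.0, 2.0, 3.0, 4.0, 100.0])
z = robust_zscore(s)
print(z.round(3).tolist())
print(shrink_outliers(z).round(3).tolist())
print(z.clip(-3, 3).round(3).tolist())
```

Shrinking keeps the ordering of outliers while bounding their magnitude, whereas clipping maps every outlier to the same value.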
@@ -1,176 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import yaml
|
||||
import copy
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from ...config import REG_CN
|
||||
|
||||
|
||||
class EstimatorConfigManager(object):
|
||||
def __init__(self, config_path):
|
||||
|
||||
if not config_path:
|
||||
raise ValueError("Config path is invalid.")
|
||||
self.config_path = config_path
|
||||
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.FullLoader)
|
||||
self.config = copy.deepcopy(config)
|
||||
|
||||
self.ex_config = ExperimentConfig(config.get("experiment", dict()), self)
|
||||
self.data_config = DataConfig(config.get("data", dict()), self)
|
||||
self.model_config = ModelConfig(config.get("model", dict()), self)
|
||||
self.trainer_config = TrainerConfig(config.get("trainer", dict()), self)
|
||||
self.strategy_config = StrategyConfig(config.get("strategy", dict()), self)
|
||||
self.backtest_config = BacktestConfig(config.get("backtest", dict()), self)
|
||||
self.qlib_data_config = QlibDataConfig(config.get("qlib_data", dict()), self)
|
||||
|
||||
# If start_date and end_date are not given in data_config, they are taken from the trainer_config.
|
||||
handler_start_date = self.data_config.handler_parameters.get("start_date", None)
|
||||
handler_end_date = self.data_config.handler_parameters.get("end_date", None)
|
||||
if handler_start_date is None:
|
||||
self.data_config.handler_parameters["start_date"] = self.trainer_config.parameters["train_start_date"]
|
||||
if handler_end_date is None:
|
||||
self.data_config.handler_parameters["end_date"] = self.trainer_config.parameters["test_end_date"]
|
||||
|
||||
|
||||
class ExperimentConfig(object):
|
||||
TRAIN_MODE = "train"
|
||||
TEST_MODE = "test"
|
||||
|
||||
OBSERVER_FILE_STORAGE = "file_storage"
|
||||
OBSERVER_MONGO = "mongo"
|
||||
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for experiment
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.name = config.get("name", "test_experiment")
|
||||
# The dir of the result of all the experiments
|
||||
self.global_dir = config.get("dir", os.path.dirname(CONFIG_MANAGER.config_path))
|
||||
# The dir of the result of current experiment
|
||||
self.ex_dir = os.path.join(self.global_dir, self.name)
|
||||
if not os.path.exists(self.ex_dir):
|
||||
os.makedirs(self.ex_dir)
|
||||
self.tmp_run_dir = tempfile.mkdtemp(dir=self.ex_dir)
|
||||
self.mode = config.get("mode", ExperimentConfig.TRAIN_MODE)
|
||||
self.sacred_dir = os.path.join(self.ex_dir, "sacred")
|
||||
self.observer_type = config.get("observer_type", ExperimentConfig.OBSERVER_FILE_STORAGE)
|
||||
self.mongo_url = config.get("mongo_url", None)
|
||||
self.db_name = config.get("db_name", None)
|
||||
self.finetune = config.get("finetune", False)
|
||||
|
||||
# The path of the experiment id of the experiment
|
||||
self.exp_info_path = config.get("exp_info_path", os.path.join(self.ex_dir, "exp_info.json"))
|
||||
exp_info_dir = Path(self.exp_info_path).parent
|
||||
exp_info_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Test mode config
|
||||
loader_args = config.get("loader", dict())
|
||||
if self.mode == ExperimentConfig.TEST_MODE or self.finetune:
|
||||
loader_exp_info_path = loader_args.get("exp_info_path", None)
|
||||
self.loader_model_index = loader_args.get("model_index", None)
|
||||
if (loader_exp_info_path is not None) and (os.path.exists(loader_exp_info_path)):
|
||||
with open(loader_exp_info_path) as fp:
|
||||
loader_dict = json.load(fp)
|
||||
for k, v in loader_dict.items():
|
||||
setattr(self, "loader_{}".format(k), v)
|
||||
# Check loader experiment id
|
||||
assert hasattr(self, "loader_id"), "If mode is test or finetune is True, loader must contain id."
|
||||
else:
|
||||
self.loader_id = loader_args.get("id", None)
|
||||
if self.loader_id is None:
|
||||
raise ValueError("If mode is test or finetune is True, loader must contain id.")
|
||||
|
||||
self.loader_observer_type = loader_args.get("observer_type", self.observer_type)
|
||||
self.loader_name = loader_args.get("name", self.name)
|
||||
self.loader_dir = loader_args.get("dir", self.global_dir)
|
||||
|
||||
self.loader_mongo_url = loader_args.get("mongo_url", self.mongo_url)
|
||||
self.loader_db_name = loader_args.get("db_name", self.db_name)
|
||||
|
||||
|
||||
class DataConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for data
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.handler_module_path = config.get("module_path", "qlib.contrib.estimator.handler")
|
||||
self.handler_class = config.get("class", "ALPHA360")
|
||||
self.handler_parameters = config.get("args", dict())
|
||||
self.handler_filter = config.get("filter", dict())
|
||||
# Update provider uri.
|
||||
|
||||
|
||||
class ModelConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for model
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.model_class = config.get("class", "Model")
|
||||
self.model_module_path = config.get("module_path", "qlib.contrib.model")
|
||||
self.save_dir = os.path.join(CONFIG_MANAGER.ex_config.tmp_run_dir, "model")
|
||||
self.save_path = config.get("save_path", os.path.join(self.save_dir, "model.bin"))
|
||||
self.parameters = config.get("args", dict())
|
||||
# Make dir if need.
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
|
||||
|
||||
class TrainerConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for trainer
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.trainer_class = config.get("class", "StaticTrainer")
|
||||
self.trainer_module_path = config.get("module_path", "qlib.contrib.estimator.trainer")
|
||||
self.parameters = config.get("args", dict())
|
||||
|
||||
|
||||
class StrategyConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for strategy
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.strategy_class = config.get("class", "TopkDropoutStrategy")
|
||||
self.strategy_module_path = config.get("module_path", "qlib.contrib.strategy.strategy")
|
||||
self.parameters = config.get("args", dict())
|
||||
|
||||
|
||||
class BacktestConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGE):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for backtest
|
||||
:param CONFIG_MANAGE: The estimator config manager
|
||||
"""
|
||||
self.normal_backtest_parameters = config.get("normal_backtest_args", dict())
|
||||
self.long_short_backtest_parameters = config.get("long_short_backtest_args", dict())
|
||||
|
||||
|
||||
class QlibDataConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGE):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for qlib_client
|
||||
:param CONFIG_MANAGE: The estimator config manager
|
||||
"""
|
||||
self.provider_uri = config.pop("provider_uri", "~/.qlib/qlib_data/cn_data")
|
||||
self.auto_mount = config.pop("auto_mount", False)
|
||||
self.mount_path = config.pop("mount_path", "~/.qlib/qlib_data/cn_data")
|
||||
self.region = config.pop("region", REG_CN)
|
||||
self.args = config
|
||||
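Reviewer note: the removed `QlibDataConfig` pops the keys it knows and keeps the remainder as `self.args`, which mutates the dict passed in by the caller. The sketch below shows the same split on a copy. The function name is hypothetical, and the default region string `"cn"` is assumed to be the value of `REG_CN`.

```python
def split_qlib_data_config(config):
    rest = dict(config)  # copy, so the caller's dict is left untouched
    known = {
        "provider_uri": rest.pop("provider_uri", "~/.qlib/qlib_data/cn_data"),
        "auto_mount": rest.pop("auto_mount", False),
        "mount_path": rest.pop("mount_path", "~/.qlib/qlib_data/cn_data"),
        "region": rest.pop("region", "cn"),  # assumed value of REG_CN
    }
    # `rest` plays the role of `self.args`: everything that was not recognised.
    return known, rest


user_config = {"region": "us", "expression_cache": None}
known, rest = split_qlib_data_config(user_config)
print(known)
print(rest)
```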
@@ -1,328 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import yaml
|
||||
import pickle
|
||||
|
||||
import qlib
|
||||
from ..evaluate import risk_analysis
|
||||
from ..evaluate import backtest as normal_backtest
|
||||
from ..evaluate import long_short_backtest
|
||||
from .config import ExperimentConfig
|
||||
from .fetcher import create_fetcher_with_config
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_module_by_module_path, compare_dict_value
|
||||
|
||||
|
||||
class Estimator(object):
|
||||
def __init__(self, config_manager, sacred_ex):
|
||||
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("Estimator")
|
||||
|
||||
# 1. Set config manager.
|
||||
self.config_manager = config_manager
|
||||
|
||||
# 2. Set configs.
|
||||
self.ex_config = config_manager.ex_config
|
||||
self.data_config = config_manager.data_config
|
||||
self.model_config = config_manager.model_config
|
||||
self.trainer_config = config_manager.trainer_config
|
||||
self.strategy_config = config_manager.strategy_config
|
||||
self.backtest_config = config_manager.backtest_config
|
||||
|
||||
# If experiment.mode is test or experiment.finetune is True, load the experimental results in the loader
|
||||
if self.ex_config.mode == self.ex_config.TEST_MODE or self.ex_config.finetune:
|
||||
self.compare_config_with_config_manger(self.config_manager)
|
||||
|
||||
# 3. Set sacred_experiment.
|
||||
self.ex = sacred_ex
|
||||
|
||||
# 4. Init data handler.
|
||||
self.data_handler = None
|
||||
self._init_data_handler()
|
||||
|
||||
# 5. Init trainer.
|
||||
self.trainer = None
|
||||
self._init_trainer()
|
||||
|
||||
# 6. Init strategy.
|
||||
self.strategy = None
|
||||
self._init_strategy()
|
||||
|
||||
def _init_data_handler(self):
|
||||
handler_module = get_module_by_module_path(self.data_config.handler_module_path)
|
||||
|
||||
# Set market
|
||||
market = self.data_config.handler_filter.get("market", None)
|
||||
if market is None:
|
||||
if "market" in self.data_config.handler_parameters:
|
||||
self.logger.warning(
|
||||
"Warning: The market in data.args section is deprecated. "
|
||||
"It only works when market is not set in data.filter section. "
|
||||
"It will be overridden by market in the data.filter section."
|
||||
)
|
||||
market = self.data_config.handler_parameters["market"]
|
||||
else:
|
||||
market = "csi500"
|
||||
|
||||
self.data_config.handler_parameters["market"] = market
|
||||
|
||||
data_filter_list = []
|
||||
handler_filters = self.data_config.handler_filter.get("filter_pipeline", list())
|
||||
for h_filter in handler_filters:
|
||||
filter_module_path = h_filter.get("module_path", "qlib.data.filter")
|
||||
filter_class_name = h_filter.get("class", "")
|
||||
filter_parameters = h_filter.get("args", {})
|
||||
filter_module = get_module_by_module_path(filter_module_path)
|
||||
filter_class = getattr(filter_module, filter_class_name)
|
||||
data_filter = filter_class(**filter_parameters)
|
||||
data_filter_list.append(data_filter)
|
||||
|
||||
self.data_config.handler_parameters["data_filter_list"] = data_filter_list
|
||||
handler_class = getattr(handler_module, self.data_config.handler_class)
|
||||
self.data_handler = handler_class(**self.data_config.handler_parameters)
|
||||
|
||||
def _init_trainer(self):
|
||||
|
||||
model_module = get_module_by_module_path(self.model_config.model_module_path)
|
||||
trainer_module = get_module_by_module_path(self.trainer_config.trainer_module_path)
|
||||
model_class = getattr(model_module, self.model_config.model_class)
|
||||
trainer_class = getattr(trainer_module, self.trainer_config.trainer_class)
|
||||
|
||||
self.trainer = trainer_class(
|
||||
model_class,
|
||||
self.model_config.save_path,
|
||||
self.model_config.parameters,
|
||||
self.data_handler,
|
||||
self.ex,
|
||||
**self.trainer_config.parameters
|
||||
)
|
||||
|
||||
def _init_strategy(self):
|
||||
|
||||
module = get_module_by_module_path(self.strategy_config.strategy_module_path)
|
||||
strategy_class = getattr(module, self.strategy_config.strategy_class)
|
||||
self.strategy = strategy_class(**self.strategy_config.parameters)
|
||||
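The three `_init_*` methods above share one pattern: resolve a module from a configured path, look the class up by name with `getattr`, and instantiate it with the configured parameters. Below is a minimal sketch of that pattern. It assumes plain `importlib.import_module` as a stand-in for `get_module_by_module_path`, whose exact behavior is not shown here; `load_class`, `build_from_config` and the example config are hypothetical names used only for illustration.

```python
import importlib


def load_class(module_path, class_name):
    """Import `module_path` and return the attribute named `class_name`."""
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def build_from_config(config):
    """Instantiate the class described by a dict with `module_path`, `class` and `args` keys."""
    cls = load_class(config["module_path"], config["class"])
    return cls(**config.get("args", {}))


# Hypothetical config: a standard-library class stands in for a strategy or handler.
example_config = {
    "module_path": "fractions",
    "class": "Fraction",
    "args": {"numerator": 3, "denominator": 4},
}
instance = build_from_config(example_config)
```

Keeping the lookup in one helper means a missing class shows up as a single `AttributeError` at construction time rather than somewhere later in the run.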
|
||||
def run(self):
|
||||
if self.ex_config.mode == ExperimentConfig.TRAIN_MODE:
|
||||
self.trainer.train()
|
||||
elif self.ex_config.mode == ExperimentConfig.TEST_MODE:
|
||||
self.trainer.load()
|
||||
else:
|
||||
raise ValueError("unexpected mode: %s" % self.ex_config.mode)
|
||||
analysis = self.backtest()
|
||||
print(analysis)
|
||||
self.logger.info(
|
||||
"experiment id: {}, experiment name: {}".format(self.ex.experiment.current_run._id, self.ex_config.name)
|
||||
)
|
||||
|
||||
# Remove temp dir
|
||||
# shutil.rmtree(self.ex_config.tmp_run_dir)
|
||||
|
||||
def backtest(self):
|
||||
TimeInspector.set_time_mark()
|
||||
# 1. Get pred and prediction score of model(s).
|
||||
pred = self.trainer.get_test_score()
|
||||
try:
|
||||
performance = self.trainer.get_test_performance()
|
||||
except NotImplementedError:
|
||||
performance = None
|
||||
# 2. Normal Backtest.
|
||||
report_normal, positions_normal = self._normal_backtest(pred)
|
||||
# 3. Long-Short Backtest.
|
||||
# Deprecated
|
||||
# long_short_reports = self._long_short_backtest(pred)
|
||||
# 4. Analyze
|
||||
analysis_df = self._analyze(report_normal)
|
||||
# 5. Save.
|
||||
self._save_backtest_result(
|
||||
pred,
|
||||
analysis_df,
|
||||
positions_normal,
|
||||
report_normal,
|
||||
# long_short_reports,
|
||||
performance,
|
||||
)
|
||||
return analysis_df
|
||||
|
||||
def _normal_backtest(self, pred):
|
||||
TimeInspector.set_time_mark()
|
||||
if "account" not in self.backtest_config.normal_backtest_parameters:
|
||||
if "account" in self.strategy_config.parameters:
|
||||
self.logger.warning(
|
||||
"Warning: The account in strategy section is deprecated. "
|
||||
"It only works when account is not set in backtest section. "
|
||||
"It will be overridden by account in the backtest section."
|
||||
)
|
||||
self.backtest_config.normal_backtest_parameters["account"] = self.strategy_config.parameters["account"]
|
||||
report_normal, positions_normal = normal_backtest(
|
||||
pred, strategy=self.strategy, **self.backtest_config.normal_backtest_parameters
|
||||
)
|
||||
TimeInspector.log_cost_time("Finished normal backtest.")
|
||||
return report_normal, positions_normal
|
||||
|
||||
def _long_short_backtest(self, pred):
|
||||
TimeInspector.set_time_mark()
|
||||
long_short_reports = long_short_backtest(pred, **self.backtest_config.long_short_backtest_parameters)
|
||||
TimeInspector.log_cost_time("Finished long-short backtest.")
|
||||
return long_short_reports
|
||||
|
||||
@staticmethod
|
||||
def _analyze(report_normal):
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
analysis = dict()
|
||||
# analysis["pred_long"] = risk_analysis(long_short_reports["long"])
|
||||
# analysis["pred_short"] = risk_analysis(long_short_reports["short"])
|
||||
# analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"])
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
TimeInspector.log_cost_time(
|
||||
"Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean())
|
||||
)
|
||||
return analysis_df
|
||||
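`_analyze` above derives two series from the backtest report before passing them to `risk_analysis`: the excess return over the benchmark, and the same series with trading cost subtracted. The sketch below reproduces only that arithmetic on plain Python lists. It does not call `risk_analysis`, and the function name and sample numbers are made up for illustration.

```python
def excess_returns(returns, bench, cost=None):
    """Per-period excess return: return - bench, minus cost when it is given."""
    if cost is None:
        cost = [0.0] * len(returns)
    return [r - b - c for r, b, c in zip(returns, bench, cost)]


# Hypothetical daily figures, not taken from any real backtest.
daily_return = [0.010, -0.004, 0.006]
daily_bench = [0.008, -0.002, 0.001]
daily_cost = [0.001, 0.001, 0.001]

without_cost = excess_returns(daily_return, daily_bench)
with_cost = excess_returns(daily_return, daily_bench, daily_cost)
```

With these numbers the first period is about 0.002 in excess return before cost and about 0.001 after it.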
|
||||
def _save_backtest_result(self, pred, analysis, positions, report_normal, performance):
|
||||
# 1. Result dir.
|
||||
result_dir = os.path.join(self.config_manager.ex_config.tmp_run_dir, "result")
|
||||
if not os.path.exists(result_dir):
|
||||
os.makedirs(result_dir)
|
||||
|
||||
self.ex.add_info(
|
||||
"task_config",
|
||||
json.loads(json.dumps(self.config_manager.config, default=str)),
|
||||
)
|
||||
|
||||
# 2. Pred.
|
||||
TimeInspector.set_time_mark()
|
||||
pred_pkl_path = os.path.join(result_dir, "pred.pkl")
|
||||
pred.to_pickle(pred_pkl_path)
|
||||
self.ex.add_artifact(pred_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving pred.pkl to: {}".format(pred_pkl_path))
|
||||
|
||||
# 3. Ana.
|
||||
TimeInspector.set_time_mark()
|
||||
analysis_pkl_path = os.path.join(result_dir, "analysis.pkl")
|
||||
analysis.to_pickle(analysis_pkl_path)
|
||||
self.ex.add_artifact(analysis_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving analysis.pkl to: {}".format(analysis_pkl_path))
|
||||
|
||||
# 4. Pos.
|
||||
TimeInspector.set_time_mark()
|
||||
positions_pkl_path = os.path.join(result_dir, "positions.pkl")
|
||||
with open(positions_pkl_path, "wb") as fp:
|
||||
pickle.dump(positions, fp)
|
||||
self.ex.add_artifact(positions_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving positions.pkl to: {}".format(positions_pkl_path))
|
||||
|
||||
# 5. Report normal.
|
||||
TimeInspector.set_time_mark()
|
||||
report_normal_pkl_path = os.path.join(result_dir, "report_normal.pkl")
|
||||
report_normal.to_pickle(report_normal_pkl_path)
|
||||
self.ex.add_artifact(report_normal_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving report_normal.pkl to: {}".format(report_normal_pkl_path))
|
||||
|
||||
# 6. Report long short.
|
||||
# Deprecated
|
||||
# for k, name in zip(
|
||||
# ["long", "short", "long_short"],
|
||||
# ["report_long.pkl", "report_short.pkl", "report_long_short.pkl"],
|
||||
# ):
|
||||
# TimeInspector.set_time_mark()
|
||||
# pkl_path = os.path.join(result_dir, name)
|
||||
# long_short_reports[k].to_pickle(pkl_path)
|
||||
# self.ex.add_artifact(pkl_path)
|
||||
# TimeInspector.log_cost_time("Finished saving {} to: {}".format(name, pkl_path))
|
||||
|
||||
# 7. Origin test label.
|
||||
TimeInspector.set_time_mark()
|
||||
label_pkl_path = os.path.join(result_dir, "label.pkl")
|
||||
self.data_handler.get_origin_test_label_with_date(
|
||||
self.trainer_config.parameters["test_start_date"],
|
||||
self.trainer_config.parameters["test_end_date"],
|
||||
).to_pickle(label_pkl_path)
|
||||
self.ex.add_artifact(label_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving label.pkl to: {}".format(label_pkl_path))
|
||||
|
||||
# 8. Experiment info, save the model(s) performance here.
|
||||
TimeInspector.set_time_mark()
|
||||
cur_ex_id = self.ex.experiment.current_run._id
|
||||
exp_info = {
|
||||
"id": cur_ex_id,
|
||||
"name": self.ex_config.name,
|
||||
"performance": performance,
|
||||
"observer_type": self.ex_config.observer_type,
|
||||
}
|
||||
|
||||
if self.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
|
||||
exp_info.update(
|
||||
{
|
||||
"mongo_url": self.ex_config.mongo_url,
|
||||
"db_name": self.ex_config.db_name,
|
||||
}
|
||||
)
|
||||
else:
|
||||
exp_info.update({"dir": self.ex_config.global_dir})
|
||||
|
||||
with open(self.ex_config.exp_info_path, "w") as fp:
|
||||
json.dump(exp_info, fp, indent=4, sort_keys=True)
|
||||
self.ex.add_artifact(self.ex_config.exp_info_path)
|
||||
TimeInspector.log_cost_time("Finished saving ex_info to: {}".format(self.ex_config.exp_info_path))
|
||||
|
||||
@staticmethod
|
||||
def compare_config_with_config_manger(config_manager):
|
||||
"""Compare the loader experiment's config with the current config from ConfigManager
|
||||
|
||||
:param config_manager: ConfigManager
|
||||
:return:
|
||||
"""
|
||||
fetcher = create_fetcher_with_config(config_manager, load_form_loader=True)
|
||||
loader_mode_config = fetcher.get_experiment(
|
||||
exp_name=config_manager.ex_config.loader_name,
|
||||
exp_id=config_manager.ex_config.loader_id,
|
||||
fields=["task_config"],
|
||||
)["task_config"]
|
||||
with open(config_manager.config_path) as fp:
|
||||
current_config = yaml.safe_load(fp)
|
||||
current_config = json.loads(json.dumps(current_config, default=str))
|
||||
|
||||
logger = get_module_logger("Estimator")
|
||||
|
||||
loader_mode_config = copy.deepcopy(loader_mode_config)
|
||||
current_config = copy.deepcopy(current_config)
|
||||
|
||||
# Require loader_mode_config.test_start_date <= current_config.test_start_date
|
||||
loader_trainer_args = loader_mode_config.get("trainer", {}).get("args", {})
|
||||
cur_trainer_args = current_config.get("trainer", {}).get("args", {})
|
||||
loader_start_date = loader_trainer_args.pop("test_start_date")
|
||||
cur_test_start_date = cur_trainer_args.pop("test_start_date")
|
||||
assert (
|
||||
loader_start_date <= cur_test_start_date
|
||||
), "Require: loader_mode_config.test_start_date <= current_config.test_start_date"
|
||||
|
||||
# TODO: For the user's own extended `Trainer`, the support is not very good
|
||||
if "RollingTrainer" == current_config.get("trainer", {}).get("class", None):
|
||||
loader_period = loader_trainer_args.pop("rolling_period")
|
||||
cur_period = cur_trainer_args.pop("rolling_period")
|
||||
assert (
|
||||
loader_period == cur_period
|
||||
), "Require: loader_mode_config.rolling_period == current_config.rolling_period"
|
||||
|
||||
compare_section = ["trainer", "model", "data"]
|
||||
for section in compare_section:
|
||||
changes = compare_dict_value(loader_mode_config.get(section, {}), current_config.get(section, {}))
|
||||
if changes:
|
||||
logger.warning("Warning: Loader mode config and current config, `{}` are different:\n".format(section))
|
||||
@@ -1,290 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
import copy
|
||||
import json
|
||||
import yaml
|
||||
import pickle
|
||||
import gridfs
|
||||
import pymongo
|
||||
from pathlib import Path
|
||||
from abc import abstractmethod
|
||||
|
||||
from .config import EstimatorConfigManager, ExperimentConfig
|
||||
|
||||
|
||||
class Fetcher(object):
|
||||
"""Sacred Experiments Fetcher"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_experiment(self, exp_name, exp_id):
|
||||
"""Get experiment basic info with experiment name and experiment id
|
||||
|
||||
:param exp_name: experiment name
|
||||
:param exp_id: experiment id
|
||||
:return: dict
|
||||
Must contain keys: _id, experiment, info, stop_time.
|
||||
Here is an example below for FileFetcher.
|
||||
exp = {
|
||||
'_id': exp_id, # experiment id
|
||||
'path': path, # experiment result path
|
||||
'experiment': {'name': exp_name}, # experiment
|
||||
'info': info, # experiment config info
|
||||
'stop_time': run.get('stop_time', None) # The time the experiment ended
|
||||
}
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _list_experiments(self, exp_name=None):
|
||||
"""Get experiment basic info list with experiment name
|
||||
|
||||
:param exp_name: experiment name
|
||||
:return: list
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _iter_artifacts(self, experiment):
|
||||
"""Get information about the data in the experiment results
|
||||
|
||||
:param experiment: `self._get_experiment` method result
|
||||
:return: iterable
|
||||
Each element contains two elements.
|
||||
first element : data name
|
||||
second element : data uri
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _load_data(self, uri):
|
||||
"""Load data with uri
|
||||
|
||||
:param uri: data uri
|
||||
:return: bytes
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def model_dict_to_buffer_list(model_dict):
|
||||
"""
|
||||
|
||||
:param model_dict:
|
||||
:return:
|
||||
"""
|
||||
model_list = []
|
||||
is_static_model = False
|
||||
if len(model_dict) == 1 and list(model_dict.keys())[0] == "model.bin":
|
||||
is_static_model = True
|
||||
model_list.append(list(model_dict.values())[0])
|
||||
else:
|
||||
sep = "model.bin_"
|
||||
model_ids = list(map(lambda x: int(x.split(sep)[1]), model_dict.keys()))
|
||||
min_id, max_id = min(model_ids), max(model_ids)
|
||||
for i in range(min_id, max_id + 1):
|
||||
model_key = sep + str(i)
|
||||
model = model_dict.get(model_key, None)
|
||||
if model is None:
|
||||
print(
|
||||
"WARNING: In Fetcher, {} is missing while collecting models in get_experiment.".format(
|
||||
model_key
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
model_list.append(model)
|
||||
|
||||
if is_static_model:
|
||||
return model_list[0]
|
||||
|
||||
return model_list
|
||||
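The ordering rule in `model_dict_to_buffer_list` is easier to see in isolation: a single `model.bin` artifact is returned as-is, while rolling models named `model.bin_<i>` are returned in index order, stopping at the first missing index. The following is a simplified sketch of that rule. The function name is hypothetical and it omits the warning message printed by the original.

```python
def order_model_buffers(model_dict, sep="model.bin_"):
    """Return one buffer for a static model, or an index-ordered list for rolling models."""
    if set(model_dict) == {"model.bin"}:
        return model_dict["model.bin"]

    ids = sorted(int(key.split(sep)[1]) for key in model_dict)
    ordered = []
    for i in range(ids[0], ids[-1] + 1):
        buffer = model_dict.get(sep + str(i))
        if buffer is None:
            # A gap in the sequence: keep only the contiguous leading run.
            break
        ordered.append(buffer)
    return ordered


# Hypothetical artifacts, given out of order.
rolling = order_model_buffers({"model.bin_2": b"c", "model.bin_0": b"a", "model.bin_1": b"b"})
```

Note that the return type depends on the input: a caller gets raw bytes for a static model and a list for rolling models, so it has to check which case it received.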
|
||||
def get_experiments(self, exp_name=None):
|
||||
"""Get experiments with name.
|
||||
|
||||
:param exp_name: str
|
||||
If `exp_name` is None, all experiments are returned.
|
||||
:return: dict
|
||||
Experiments info dict(Including experiment id and task_config to run the
|
||||
experiment). Here is an example below.
|
||||
{
|
||||
'a_experiment': [
|
||||
{
|
||||
'id': '1',
|
||||
'task_config': {...}
|
||||
},
|
||||
...
|
||||
]
|
||||
...
|
||||
}
|
||||
"""
|
||||
res = dict()
|
||||
for ex in self._list_experiments(exp_name):
|
||||
name = ex["experiment"]["name"]
|
||||
tmp = {
|
||||
"id": ex["_id"],
|
||||
"task_config": ex["info"].get("task_config", {}),
|
||||
"ex_run_stop_time": ex.get("stop_time", None),
|
||||
}
|
||||
res.setdefault(name, []).append(tmp)
|
||||
return res
|
||||
|
||||
def get_experiment(self, exp_name, exp_id, fields=None):
|
||||
"""
|
||||
|
||||
:param exp_name:
|
||||
:param exp_id:
|
||||
:param fields: list
|
||||
Experiment result fields, if fields is None, will get all fields.
|
||||
Currently supported fields:
|
||||
['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
|
||||
:return: dict
|
||||
"""
|
||||
fields = copy.copy(fields)
|
||||
ex = self._get_experiment(exp_name, exp_id)
|
||||
results = dict()
|
||||
model_dict = dict()
|
||||
for name, uri in self._iter_artifacts(ex):
|
||||
# When saving, use `sacred.experiment.add_artifact(filename)` , so `name` is os.path.basename(filename)
|
||||
prefix = name.split(".")[0]
|
||||
if fields and prefix not in fields:
|
||||
continue
|
||||
data = self._load_data(uri)
|
||||
if prefix == "model":
|
||||
model_dict[name] = data
|
||||
else:
|
||||
results[prefix] = pickle.loads(data)
|
||||
# Sort model
|
||||
if model_dict:
|
||||
results["model"] = self.model_dict_to_buffer_list(model_dict)
|
||||
|
||||
# Info
|
||||
results["task_config"] = ex["info"].get("task_config", {})
|
||||
return results
|
||||
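`get_experiment` decides which artifacts to load by taking the part of the artifact name before the first dot and checking it against `fields`. A small sketch of just that filter follows, with a hypothetical function name and hypothetical artifact names. Actual loading and unpickling are left out.

```python
def select_artifacts(artifact_names, fields=None):
    """Keep artifacts whose name prefix (text before the first '.') is in `fields`.

    An empty or None `fields` keeps everything, mirroring the `if fields and ...` check.
    """
    selected = []
    for name in artifact_names:
        prefix = name.split(".")[0]
        if fields and prefix not in fields:
            continue
        selected.append(name)
    return selected


# Hypothetical artifact names as they might appear in a result directory.
names = ["pred.pkl", "analysis.pkl", "model.bin_0", "model.bin_1", "label.pkl"]
wanted = select_artifacts(names, fields=["model", "pred"])
```

Because the prefix is cut at the first dot, both `model.bin_0` and `model.bin_1` match the single field name `model`.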
|
||||
def estimator_config_to_dict(self, exp_name, exp_id):
|
||||
"""Get the task configuration of an experiment as a dict
|
||||
|
||||
:param exp_name:
|
||||
:param exp_id:
|
||||
:return: config dict
|
||||
"""
|
||||
|
||||
return self.get_experiment(exp_name, exp_id, fields=["task_config"])["task_config"]
|
||||
|
||||
|
||||
class FileFetcher(Fetcher):
|
||||
"""File Fetcher"""
|
||||
|
||||
def __init__(self, experiments_dir):
|
||||
self.experiments_dir = Path(experiments_dir)
|
||||
|
||||
def _get_experiment(self, exp_name, exp_id):
|
||||
path = self.experiments_dir / exp_name / "sacred" / str(exp_id)
|
||||
info_path = path / "info.json"
|
||||
run_path = path / "run.json"
|
||||
|
||||
if info_path.exists():
|
||||
with info_path.open("r") as f:
|
||||
info = json.load(f)
|
||||
else:
|
||||
info = {}
|
||||
|
||||
if run_path.exists():
|
||||
with run_path.open("r") as f:
|
||||
run = json.load(f)
|
||||
else:
|
||||
run = {}
|
||||
|
||||
exp = {
|
||||
"_id": exp_id,
|
||||
"path": path,
|
||||
"experiment": {"name": exp_name},
|
||||
"info": info,
|
||||
"stop_time": run.get("stop_time", None),
|
||||
}
|
||||
return exp
|
||||
|
||||
def _list_experiments(self, exp_name=None):
|
||||
runs = []
|
||||
for path in self.experiments_dir.glob("{}/sacred/[!_]*".format(exp_name or "*")):
|
||||
exp_name, exp_id = path.parents[1].name, path.name
|
||||
runs.append(self._get_experiment(exp_name, exp_id))
|
||||
return runs
|
||||
|
||||
def _iter_artifacts(self, experiment):
|
||||
if experiment is None:
|
||||
return []
|
||||
|
||||
for fname in experiment["path"].iterdir():
|
||||
if fname.suffix == ".pkl" or ".bin" in fname.suffix:
|
||||
name, uri = fname.name, str(fname)
|
||||
yield name, uri
|
||||
|
||||
def _load_data(self, uri):
|
||||
with open(uri, "rb") as f:
|
||||
data = f.read()
|
||||
return data
|
||||
|
||||
|
||||
class MongoFetcher(Fetcher):
|
||||
"""MongoDB Fetcher"""
|
||||
|
||||
def __init__(self, mongo_url, db_name):
|
||||
self.mongo_url = mongo_url
|
||||
self.db_name = db_name
|
||||
self.client = None
|
||||
self.db = None
|
||||
self.runs = None
|
||||
self.fs = None
|
||||
self._setup_mongo_client()
|
||||
|
||||
def _setup_mongo_client(self):
|
||||
self.client = pymongo.MongoClient(self.mongo_url)
|
||||
self.db = self.client[self.db_name]
|
||||
self.runs = self.db.runs
|
||||
self.fs = gridfs.GridFS(self.db)
|
||||
|
||||
def _get_experiment(self, exp_name, exp_id):
|
||||
return self.runs.find_one({"_id": exp_id})
|
||||
|
||||
def _list_experiments(self, exp_name=None):
|
||||
if exp_name is None:
|
||||
return self.runs.find()
|
||||
return self.runs.find({"experiment.name": exp_name})
|
||||
|
||||
def _iter_artifacts(self, experiment):
|
||||
if experiment is None:
|
||||
return []
|
||||
for artifact in experiment.get("artifacts", []):
|
||||
name, uri = artifact["name"], artifact["file_id"]
|
||||
yield name, uri
|
||||
|
||||
def _load_data(self, uri):
|
||||
data = self.fs.get(uri).read()
|
||||
return data
|
||||
|
||||
|
||||
def create_fetcher_with_config(config_manager: EstimatorConfigManager, load_form_loader: bool = False):
|
||||
"""Create fetcher with loader config
|
||||
|
||||
:param config_manager:
|
||||
:param load_form_loader: if True, use the loader_* settings instead of the global ones
|
||||
:return:
|
||||
"""
|
||||
flag = ""
|
||||
if load_form_loader:
|
||||
flag = "loader_"
|
||||
if config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_FILE_STORAGE:
|
||||
return FileFetcher(getattr(config_manager.ex_config, "{}_dir".format("loader" if load_form_loader else "global")))
|
||||
elif config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
|
||||
return MongoFetcher(
|
||||
mongo_url=getattr(config_manager.ex_config, "{}mongo_url".format(flag)),
|
||||
db_name=getattr(config_manager.ex_config, "{}db_name".format(flag)),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown backend")
|
||||
@@ -1,585 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
import abc
|
||||
import bisect
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...data import D
|
||||
from ...utils import parse_config, transform_end_date
|
||||
|
||||
from . import processor as processor_module
|
||||
|
||||
|
||||
class BaseDataHandler(abc.ABC):
|
||||
def __init__(self, processors=(), **kwargs):
|
||||
"""
|
||||
:param start_date:
|
||||
:param end_date:
|
||||
:param kwargs:
|
||||
"""
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
|
||||
# init data using kwargs
|
||||
self._init_kwargs(**kwargs)
|
||||
|
||||
# Setup data.
|
||||
self.raw_df, self.feature_names, self.label_names = self._init_raw_df()
|
||||
|
||||
# Setup preprocessor
|
||||
self.processors = []
|
||||
for klass in processors:
|
||||
if isinstance(klass, str):
|
||||
try:
|
||||
klass = getattr(processor_module, klass)
|
||||
except AttributeError:
|
||||
raise ValueError("unknown Processor %s" % klass)
|
||||
self.processors.append(klass(self.feature_names, self.label_names, **kwargs))
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
"""
|
||||
init the kwargs of DataHandler
|
||||
"""
|
||||
pass
|
||||
|
||||
def _init_raw_df(self):
|
||||
"""
|
||||
init raw_df, feature_names, label_names of DataHandler
|
||||
If the indexes of df_features and df_labels differ, override this method to choose how they are merged (e.g. inner, left or right merge).
|
||||
|
||||
"""
|
||||
df_features = self.setup_feature()
|
||||
feature_names = df_features.columns
|
||||
|
||||
df_labels = self.setup_label()
|
||||
label_names = df_labels.columns
|
||||
|
||||
raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left")
|
||||
|
||||
return raw_df, feature_names, label_names
|
||||
|
||||
def reset_label(self, df_labels):
|
||||
for col in self.label_names:
|
||||
del self.raw_df[col]
|
||||
self.label_names = df_labels.columns
|
||||
self.raw_df = self.raw_df.merge(df_labels, left_index=True, right_index=True, how="left")
|
||||
|
||||
def split_rolling_periods(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq="day",
|
||||
):
|
||||
"""
|
||||
Calculate the rolling split periods; the periods roll along the market calendar.
|
||||
:param train_start_date:
|
||||
:param train_end_date:
|
||||
:param validate_start_date:
|
||||
:param validate_end_date:
|
||||
:param test_start_date:
|
||||
:param test_end_date:
|
||||
:param rolling_period: The market period of rolling
|
||||
:param calendar_freq: The frequency of the market calendar
|
||||
:yield: Rolling split periods
|
||||
"""
|
||||
|
||||
def get_start_index(calendar, start_date):
|
||||
start_index = bisect.bisect_left(calendar, start_date)
|
||||
return start_index
|
||||
|
||||
def get_end_index(calendar, end_date):
|
||||
end_index = bisect.bisect_right(calendar, end_date)
|
||||
return end_index - 1
|
||||
|
||||
calendar = self.raw_df.index.get_level_values("datetime").unique()
|
||||
|
||||
train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date))
|
||||
train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date))
|
||||
valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date))
|
||||
valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date))
|
||||
test_start_index = get_start_index(calendar, pd.Timestamp(test_start_date))
|
||||
test_end_index = test_start_index + rolling_period - 1
|
||||
|
||||
need_stop_split = False
|
||||
|
||||
bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date))
|
||||
|
||||
while not need_stop_split:
|
||||
|
||||
if test_end_index > bound_test_end_index:
|
||||
test_end_index = bound_test_end_index
|
||||
need_stop_split = True
|
||||
|
||||
yield (
|
||||
calendar[train_start_index],
|
||||
calendar[train_end_index],
|
||||
calendar[valid_start_index],
|
||||
calendar[valid_end_index],
|
||||
calendar[test_start_index],
|
||||
calendar[test_end_index],
|
||||
)
|
||||
|
||||
train_start_index += rolling_period
|
||||
train_end_index += rolling_period
|
||||
valid_start_index += rolling_period
|
||||
valid_end_index += rolling_period
|
||||
test_start_index += rolling_period
|
||||
test_end_index += rolling_period
|
||||
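The rolling split above maps dates to calendar positions with `bisect`, then advances every window by `rolling_period` positions and clips the last test window to the requested end date. Here is a reduced sketch that tracks only the test window. It works on any sorted sequence (integers here, in place of timestamps), the function name is hypothetical, and it deliberately stops as soon as the window reaches the bound (`>=`), so a range that divides evenly does not produce an extra window.

```python
import bisect


def rolling_test_windows(calendar, test_start, test_end, rolling_period):
    """Yield (start, end) calendar entries for consecutive test windows."""
    start = bisect.bisect_left(calendar, test_start)     # first position at or after test_start
    bound = bisect.bisect_right(calendar, test_end) - 1  # last position at or before test_end
    end = start + rolling_period - 1

    done = False
    while not done:
        if end >= bound:
            # Clip the final window to the requested end and stop afterwards.
            end = bound
            done = True
        yield calendar[start], calendar[end]
        start += rolling_period
        end += rolling_period


# Integers stand in for trading days in this illustration.
calendar = list(range(10))
windows = list(rolling_test_windows(calendar, test_start=4, test_end=9, rolling_period=3))
```

Rolling by calendar position rather than by date arithmetic keeps every window the same number of trading days long, regardless of weekends and holidays.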
|
||||
def get_rolling_data(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq="day",
|
||||
):
|
||||
# Set generator.
|
||||
for period in self.split_rolling_periods(
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
rolling_period,
|
||||
calendar_freq,
|
||||
):
|
||||
(
|
||||
x_train,
|
||||
y_train,
|
||||
x_validate,
|
||||
y_validate,
|
||||
x_test,
|
||||
y_test,
|
||||
) = self.get_split_data(*period)
|
||||
yield x_train, y_train, x_validate, y_validate, x_test, y_test
|
||||
|
||||
def get_split_data(
|
||||
self,
|
||||
train_start_date,
|
||||
train_end_date,
|
||||
validate_start_date,
|
||||
validate_end_date,
|
||||
test_start_date,
|
||||
test_end_date,
|
||||
):
|
||||
"""
|
||||
all return types are DataFrame
|
||||
"""
|
||||
# TODO: loc can be slow, especially when slicing on the second index level.
|
||||
if self.raw_df.index.names[0] == "instrument":
|
||||
df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date]
|
||||
df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date]
|
||||
df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date]
|
||||
else:
|
||||
df_train = self.raw_df.loc[train_start_date:train_end_date]
|
||||
df_validate = self.raw_df.loc[validate_start_date:validate_end_date]
|
||||
df_test = self.raw_df.loc[test_start_date:test_end_date]
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
df_train, df_validate, df_test = self.setup_process_data(df_train, df_validate, df_test)
|
||||
TimeInspector.log_cost_time("Finished setup processed data.")
|
||||
|
||||
x_train = df_train[self.feature_names]
|
||||
y_train = df_train[self.label_names]
|
||||
|
||||
x_validate = df_validate[self.feature_names]
|
||||
y_validate = df_validate[self.label_names]
|
||||
|
||||
x_test = df_test[self.feature_names]
|
||||
y_test = df_test[self.label_names]
|
||||
|
||||
return x_train, y_train, x_validate, y_validate, x_test, y_test
|
||||
|
||||
def setup_process_data(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
process the train, valid and test data
|
||||
:return: the processed train, valid and test data.
|
||||
"""
|
||||
for processor in self.processors:
|
||||
df_train, df_valid, df_test = processor(df_train, df_valid, df_test)
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def get_origin_test_label_with_date(self, test_start_date, test_end_date, freq="day"):
|
||||
"""Get origin test label
|
||||
|
||||
:param test_start_date: test start date
|
||||
:param test_end_date: test end date
|
||||
:param freq: freq
|
||||
:return: pd.DataFrame
|
||||
"""
|
||||
test_end_date = transform_end_date(test_end_date, freq=freq)
|
||||
return self.raw_df.loc[(slice(None), slice(test_start_date, test_end_date)), self.label_names]
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_feature(self):
|
||||
"""
|
||||
Implement this method to load raw feature.
|
||||
the format of the feature is below
|
||||
return: df_features
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup_label(self):
|
||||
"""
|
||||
Implement this method to load and calculate label.
|
||||
the format of the label is below
|
||||
|
||||
return: df_label
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class QLibDataHandler(BaseDataHandler):
|
||||
def __init__(self, start_date, end_date, *args, **kwargs):
|
||||
# Dates.
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
|
||||
# Instruments
|
||||
instruments = kwargs.get("instruments", None)
|
||||
if instruments is None:
|
||||
market = kwargs.get("market", "csi500").lower()
|
||||
data_filter_list = kwargs.get("data_filter_list", list())
|
||||
self.instruments = D.instruments(market, filter_pipe=data_filter_list)
|
||||
else:
|
||||
self.instruments = instruments
|
||||
|
||||
# Config of features and labels
|
||||
self._fields = kwargs.get("fields", [])
|
||||
self._names = kwargs.get("names", [])
|
||||
self._labels = kwargs.get("labels", [])
|
||||
self._label_names = kwargs.get("label_names", [])
|
||||
|
||||
# Check arguments
|
||||
assert len(self._fields) > 0, "features list is empty"
|
||||
assert len(self._labels) > 0, "labels list is empty"
|
||||
|
||||
# Check end_date
|
||||
# If test_end_date is -1 or greater than the last date, the last date is used
|
||||
self.end_date = transform_end_date(self.end_date)
|
||||
|
||||
def setup_feature(self):
|
||||
"""
|
||||
Load the raw data.
|
||||
return: df_features
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
if len(self._names) == 0:
|
||||
names = ["F%d" % i for i in range(len(self._fields))]
|
||||
else:
|
||||
names = self._names
|
||||
|
||||
df_features = D.features(self.instruments, self._fields, self.start_date, self.end_date)
|
||||
df_features.columns = names
|
||||
|
||||
TimeInspector.log_cost_time("Finished loading features.")
|
||||
|
||||
return df_features
|
||||
|
||||
def setup_label(self):
|
||||
"""
|
||||
Build up labels in df through users' method
|
||||
:return: df_labels
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
if len(self._label_names) == 0:
|
||||
label_names = ["LABEL%d" % i for i in range(len(self._labels))]
|
||||
else:
|
||||
label_names = self._label_names
|
||||
|
||||
df_labels = D.features(self.instruments, self._labels, self.start_date, self.end_date)
|
||||
df_labels.columns = label_names
|
||||
|
||||
TimeInspector.log_cost_time("Finished loading labels.")
|
||||
|
||||
return df_labels
|
||||
|
||||
|
||||
def parse_config_to_fields(config):
|
||||
"""create factors from config
|
||||
|
||||
config = {
|
||||
'kbar': {}, # whether to use some hard-code kbar features
|
||||
'price': { # whether to use raw price features
|
||||
'windows': [0, 1, 2, 3, 4], # use price at n days ago
|
||||
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
|
||||
},
|
||||
'volume': { # whether to use raw volume features
|
||||
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
|
||||
},
|
||||
'rolling': { # whether to use rolling operator based features
|
||||
'windows': [5, 10, 20, 30, 60], # rolling windows size
|
||||
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
|
||||
#if include is None we will use default operators
|
||||
'exclude': ['RANK'], # rolling operator not to use
|
||||
}
|
||||
}
|
||||
"""
|
||||
fields = []
|
||||
names = []
|
||||
if "kbar" in config:
|
||||
fields += [
|
||||
"($close-$open)/$open",
|
||||
"($high-$low)/$open",
|
||||
"($close-$open)/($high-$low+1e-12)",
|
||||
"($high-Greater($open, $close))/$open",
|
||||
"($high-Greater($open, $close))/($high-$low+1e-12)",
|
||||
"(Less($open, $close)-$low)/$open",
|
||||
"(Less($open, $close)-$low)/($high-$low+1e-12)",
|
||||
"(2*$close-$high-$low)/$open",
|
||||
"(2*$close-$high-$low)/($high-$low+1e-12)",
|
||||
]
|
||||
names += [
|
||||
"KMID",
|
||||
"KLEN",
|
||||
"KMID2",
|
||||
"KUP",
|
||||
"KUP2",
|
||||
"KLOW",
|
||||
"KLOW2",
|
||||
"KSFT",
|
||||
"KSFT2",
|
||||
]
|
||||
if "price" in config:
|
||||
windows = config["price"].get("windows", range(5))
|
||||
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
|
||||
for field in feature:
|
||||
field = field.lower()
|
||||
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
|
||||
names += [field.upper() + str(d) for d in windows]
|
||||
if "volume" in config:
|
||||
windows = config["volume"].get("windows", range(5))
|
||||
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
|
||||
names += ["VOLUME" + str(d) for d in windows]
|
||||
if "rolling" in config:
|
||||
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
|
||||
include = config["rolling"].get("include", None)
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config: unnecessary fields
|
||||
# `include` in dataset config: necessary fields
|
||||
use = lambda x: x not in exclude and (include is None or x in include)
|
||||
if use("ROC"):
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
if use("MA"):
|
||||
fields += ["Mean($close, %d)/$close" % d for d in windows]
|
||||
names += ["MA%d" % d for d in windows]
|
||||
if use("STD"):
|
||||
fields += ["Std($close, %d)/$close" % d for d in windows]
|
||||
names += ["STD%d" % d for d in windows]
|
||||
if use("BETA"):
|
||||
fields += ["Slope($close, %d)/$close" % d for d in windows]
|
||||
names += ["BETA%d" % d for d in windows]
|
||||
if use("RSQR"):
|
||||
fields += ["Rsquare($close, %d)" % d for d in windows]
|
||||
names += ["RSQR%d" % d for d in windows]
|
||||
if use("RESI"):
|
||||
fields += ["Resi($close, %d)/$close" % d for d in windows]
|
||||
names += ["RESI%d" % d for d in windows]
|
||||
if use("MAX"):
|
||||
fields += ["Max($high, %d)/$close" % d for d in windows]
|
||||
names += ["MAX%d" % d for d in windows]
|
||||
if use("LOW"):
|
||||
fields += ["Min($low, %d)/$close" % d for d in windows]
|
||||
names += ["MIN%d" % d for d in windows]
|
||||
if use("QTLU"):
|
||||
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
|
||||
names += ["QTLU%d" % d for d in windows]
|
||||
if use("QTLD"):
|
||||
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
|
||||
names += ["QTLD%d" % d for d in windows]
|
||||
if use("RANK"):
|
||||
fields += ["Rank($close, %d)" % d for d in windows]
|
||||
names += ["RANK%d" % d for d in windows]
|
||||
if use("RSV"):
|
||||
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
|
||||
names += ["RSV%d" % d for d in windows]
|
||||
if use("IMAX"):
|
||||
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMAX%d" % d for d in windows]
|
||||
if use("IMIN"):
|
||||
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
|
||||
names += ["IMIN%d" % d for d in windows]
|
||||
if use("IMXD"):
|
||||
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
|
||||
names += ["IMXD%d" % d for d in windows]
|
||||
if use("CORR"):
|
||||
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
|
||||
names += ["CORR%d" % d for d in windows]
|
||||
if use("CORD"):
|
||||
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
|
||||
names += ["CORD%d" % d for d in windows]
|
||||
if use("CNTP"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTP%d" % d for d in windows]
|
||||
if use("CNTN"):
|
||||
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
|
||||
names += ["CNTN%d" % d for d in windows]
|
||||
if use("CNTD"):
|
||||
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
|
||||
names += ["CNTD%d" % d for d in windows]
|
||||
if use("SUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMP%d" % d for d in windows]
|
||||
if use("SUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMN%d" % d for d in windows]
|
||||
if use("SUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
|
||||
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["SUMD%d" % d for d in windows]
|
||||
if use("VMA"):
|
||||
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VMA%d" % d for d in windows]
|
||||
if use("VSTD"):
|
||||
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
|
||||
names += ["VSTD%d" % d for d in windows]
|
||||
if use("WVMA"):
|
||||
fields += [
|
||||
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
|
||||
% (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["WVMA%d" % d for d in windows]
|
||||
if use("VSUMP"):
|
||||
fields += [
|
||||
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMP%d" % d for d in windows]
|
||||
if use("VSUMN"):
|
||||
fields += [
|
||||
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMN%d" % d for d in windows]
|
||||
if use("VSUMD"):
|
||||
fields += [
|
||||
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
|
||||
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
|
||||
for d in windows
|
||||
]
|
||||
names += ["VSUMD%d" % d for d in windows]
|
||||
|
||||
return fields, names
|
||||
|
||||
|
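The `include`/`exclude` filtering in `parse_config_to_fields` can be sketched in isolation. The snippet below is a minimal illustration, not the function from the diff: the name `rolling_fields` and the `templates` table are hypothetical, and only three of the rolling operators are reproduced.

```python
def rolling_fields(windows, include=None, exclude=()):
    """Build (fields, names) for a small subset of the rolling operators."""
    # Hypothetical template table; the source spells out each operator
    # in its own `if use(...)` block instead.
    templates = {
        "ROC": "Ref($close, %d)/$close",
        "MA": "Mean($close, %d)/$close",
        "STD": "Std($close, %d)/$close",
    }

    def use(op):
        # Keep an operator when it is not excluded and, if `include`
        # is given, when it is listed there.
        return op not in exclude and (include is None or op in include)

    fields, names = [], []
    for op, template in templates.items():
        if use(op):
            fields += [template % d for d in windows]
            names += ["%s%d" % (op, d) for d in windows]
    return fields, names


fields, names = rolling_fields([5, 10], include=["ROC", "MA"], exclude=["MA"])
print(names)
```

As in the source, `exclude` takes precedence: an operator listed in both `include` and `exclude` is dropped.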
||||
class ConfigQLibDataHandler(QLibDataHandler):
|
||||
config_template = {} # template
|
||||
|
||||
def __init__(self, start_date, end_date, processors=None, **kwargs):
|
||||
if processors is None:
|
||||
processors = ["ConfigSectionProcessor"] # default processor
|
||||
super().__init__(start_date, end_date, processors, **kwargs)
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
config = self.config_template.copy()
|
||||
if "config_update" in kwargs:
|
||||
config.update(kwargs["config_update"])
|
||||
fields, names = parse_config_to_fields(config)
|
||||
kwargs["fields"] = fields
|
||||
kwargs["names"] = names
|
||||
if "labels" not in kwargs:
|
||||
kwargs["labels"] = ["Ref($vwap, -2)/Ref($vwap, -1) - 1"]
|
||||
super()._init_kwargs(**kwargs)
|
||||
|
||||
|
||||
class ALPHA360(ConfigQLibDataHandler):
|
||||
config_template = {
|
||||
"price": {"windows": range(60)},
|
||||
"volume": {"windows": range(60)},
|
||||
}
|
||||
|
||||
|
||||
class QLibDataHandlerV1(ConfigQLibDataHandler):
|
||||
config_template = {
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
|
||||
def __init__(self, start_date, end_date, processors=None, **kwargs):
|
||||
if processors is None:
|
||||
processors = ["PanelProcessor"] # V1 default processor
|
||||
super().__init__(start_date, end_date, processors, **kwargs)
|
||||
|
||||
def setup_label(self):
|
||||
"""
|
||||
load the labels df
|
||||
:return: df_labels
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
df_labels = super().setup_label()
|
||||
|
||||
## calculate new labels
|
||||
df_labels["LABEL1"] = df_labels["LABEL0"].groupby(level="datetime").apply(lambda x: (x - x.mean()) / x.std())
|
||||
|
||||
df_labels = df_labels.drop(["LABEL0"], axis=1)
|
||||
|
||||
TimeInspector.log_cost_time("Finished loading labels.")
|
||||
|
||||
return df_labels
|
||||
|
||||
|
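`QLibDataHandlerV1.setup_label` standardizes the raw label within each trading day. The following is a dependency-free sketch of that cross-sectional z-score, assuming labels are keyed by `(datetime, instrument)`; the function name `zscore_by_date` and the sample data are hypothetical. Dates with fewer than two instruments or zero variance are not handled here.

```python
from collections import defaultdict
from statistics import mean, stdev


def zscore_by_date(labels):
    """labels: dict mapping (datetime, instrument) -> raw label value."""
    by_date = defaultdict(list)
    for (date, _instrument), value in labels.items():
        by_date[date].append(value)
    # Per-date mean and sample standard deviation (ddof=1, which is also
    # the default of pandas' `std`).
    stats = {date: (mean(vals), stdev(vals)) for date, vals in by_date.items()}
    return {
        key: (value - stats[key[0]][0]) / stats[key[0]][1]
        for key, value in labels.items()
    }


raw = {
    ("2020-01-02", "A"): 0.01,
    ("2020-01-02", "B"): 0.03,
    ("2020-01-03", "A"): -0.02,
    ("2020-01-03", "B"): 0.02,
}
normalized = zscore_by_date(raw)
print(normalized)
```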
||||
class Alpha158(QLibDataHandlerV1):
|
||||
config_template = {
|
||||
"kbar": {},
|
||||
"price": {
|
||||
"windows": [0],
|
||||
"feature": ["OPEN", "HIGH", "LOW", "CLOSE"],
|
||||
},
|
||||
"rolling": {},
|
||||
}
|
||||
|
||||
def _init_kwargs(self, **kwargs):
|
||||
kwargs["labels"] = ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
super(Alpha158, self)._init_kwargs(**kwargs)
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# import qlib
|
||||
#
|
||||
# qlib.init()
|
||||
#
|
||||
# handler = ALPHA80('2010-01-01', '2018-12-31')
|
||||
# data = handler.get_split_data(
|
||||
# pd.Timestamp('2010-01-01'), pd.Timestamp('2014-01-01'),
|
||||
# pd.Timestamp('2015-01-01'), pd.Timestamp('2016-01-01'),
|
||||
# pd.Timestamp('2017-01-01'), pd.Timestamp('2018-01-01'))
|
||||
# print(data[0])
|
||||
# data[0].to_pickle('alpha80.pkl')
|
||||
@@ -1,115 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
from ... import init
|
||||
from .config import EstimatorConfigManager
|
||||
from ...log import get_module_logger
|
||||
from sacred import Experiment
|
||||
from sacred.observers import FileStorageObserver
|
||||
from sacred.observers import MongoObserver
|
||||
|
||||
args_parser = argparse.ArgumentParser(prog="estimator")
|
||||
args_parser.add_argument(
|
||||
"-c",
|
||||
"--config_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="json config path indicates where to load config.",
|
||||
)
|
||||
|
||||
args = args_parser.parse_args()
|
||||
|
||||
|
||||
class SacredExperiment(object):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name,
|
||||
experiment_dir,
|
||||
observer_type="file_storage",
|
||||
mongo_url=None,
|
||||
db_name=None,
|
||||
):
|
||||
"""__init__
|
||||
|
||||
:param experiment_name: The name of the experiments.
|
||||
:param experiment_dir: The directory to store all the results of the experiments (used by the `file_storage` observer).
|
||||
:param observer_type: The observer to record the results: the `file_storage` or `mongo`
|
||||
:param mongo_url: The mongo url(for mongo observer)
|
||||
:param db_name: The mongo database name (for mongo observer)
|
||||
"""
|
||||
self.experiment_name = experiment_name
|
||||
self.experiment = Experiment(self.experiment_name)
|
||||
self.experiment_dir = experiment_dir
|
||||
self.experiment.logger = get_module_logger("Sacred")
|
||||
|
||||
self.observer_type = observer_type
|
||||
self.mongo_db_url = mongo_url
|
||||
self.mongo_db_name = db_name
|
||||
|
||||
self._setup_experiment()
|
||||
|
||||
def _setup_experiment(self):
|
||||
if self.observer_type == "file_storage":
|
||||
file_storage_observer = FileStorageObserver.create(basedir=self.experiment_dir)
|
||||
self.experiment.observers.append(file_storage_observer)
|
||||
elif self.observer_type == "mongo":
|
||||
mongo_observer = MongoObserver.create(url=self.mongo_db_url, db_name=self.mongo_db_name)
|
||||
self.experiment.observers.append(mongo_observer)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported observer type: {}".format(self.observer_type))
|
||||
|
||||
def add_artifact(self, filename):
|
||||
self.experiment.add_artifact(filename)
|
||||
|
||||
def add_info(self, key, value):
|
||||
self.experiment.info[key] = value
|
||||
|
||||
def main_wrapper(self, func):
|
||||
return self.experiment.main(func)
|
||||
|
||||
def config_wrapper(self, func):
|
||||
return self.experiment.config(func)
|
||||
|
||||
|
||||
CONFIG_MANAGER = EstimatorConfigManager(args.config_path)
|
||||
|
||||
ex = SacredExperiment(
|
||||
CONFIG_MANAGER.ex_config.name,
|
||||
CONFIG_MANAGER.ex_config.sacred_dir,
|
||||
observer_type=CONFIG_MANAGER.ex_config.observer_type,
|
||||
mongo_url=CONFIG_MANAGER.ex_config.mongo_url,
|
||||
db_name=CONFIG_MANAGER.ex_config.db_name,
|
||||
)
|
||||
|
||||
# qlib init
|
||||
init(
|
||||
provider_uri=CONFIG_MANAGER.qlib_data_config.provider_uri,
|
||||
mount_path=CONFIG_MANAGER.qlib_data_config.mount_path,
|
||||
auto_mount=CONFIG_MANAGER.qlib_data_config.auto_mount,
|
||||
region=CONFIG_MANAGER.qlib_data_config.region,
|
||||
**CONFIG_MANAGER.qlib_data_config.args
|
||||
)
|
||||
|
||||
|
||||
@ex.main_wrapper
|
||||
def _main():
|
||||
# 1. Get estimator class.
|
||||
estimator_class = getattr(
|
||||
importlib.import_module(".estimator", package="qlib.contrib.estimator"),
|
||||
"Estimator",
|
||||
)
|
||||
# 2. Init estimator.
|
||||
estimator = estimator_class(CONFIG_MANAGER, ex)
|
||||
estimator.run()
|
||||
|
||||
|
||||
def run():
|
||||
ex.experiment.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
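Note that the launcher above calls `parse_args()` at module level, so importing it requires the `-c` argument to be present on the command line. A minimal sketch of the same parser with parsing deferred into a function is shown below; `build_parser` and `parse_cli` are hypothetical names, not part of the removed module.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(prog="estimator")
    parser.add_argument(
        "-c",
        "--config_path",
        required=True,
        type=str,
        help="Path of the config file to load.",
    )
    return parser


def parse_cli(argv=None):
    # With argv=None, argparse falls back to sys.argv[1:].
    return build_parser().parse_args(argv)


args = parse_cli(["-c", "estimator_config.yaml"])
print(args.config_path)
```

Passing an explicit list makes the parser usable from tests without touching `sys.argv`.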
||||
@@ -1,249 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...log import TimeInspector
|
||||
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
class Processor(abc.ABC):
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
self.feature_names = feature_names
|
||||
self.label_names = label_names
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, df_train, df_valid, df_test):
|
||||
pass
|
||||
|
||||
|
||||
class PanelProcessor(Processor):
|
||||
"""Panel Preprocessor"""
|
||||
|
||||
STD_NORM = "Std"
|
||||
MINMAX_NORM = "MinMax"
|
||||
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
super().__init__(feature_names, label_names)
|
||||
# Options.
|
||||
self.dropna_label = kwargs.get("dropna_label", True)
|
||||
self.dropna_feature = kwargs.get("dropna_feature", False)
|
||||
self.normalize_method = kwargs.get("normalize_method", None)
|
||||
self.replace_inf = kwargs.get("replace_inf_feature", False)
|
||||
|
||||
def __call__(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Preprocess the data
|
||||
:param df_train, df_valid, df_test: the dataframes to preprocess.
|
||||
"""
|
||||
# Drop null labels.
|
||||
if self.dropna_label:
|
||||
df_train, df_valid, df_test = self._process_drop_null_label(df_train, df_valid, df_test)
|
||||
|
||||
# Drop rows with null features if needed.
|
||||
if self.dropna_feature:
|
||||
df_train, df_valid, df_test = self._process_drop_null_feature(df_train, df_valid, df_test)
|
||||
|
||||
# replace the 'inf' with the mean of the corresponding dimension
|
||||
if self.replace_inf:
|
||||
df_train, df_valid, df_test = self._process_replace_inf_feature(df_train, df_valid, df_test)
|
||||
|
||||
# normalize data in given method.
|
||||
if self.normalize_method is not None:
|
||||
df_train, df_valid, df_test = self._process_normalize_feature(df_train, df_valid, df_test)
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_drop_null_label(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Drop null labels.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
df_train = df_train.dropna(subset=self.label_names)
|
||||
df_valid = df_valid.dropna(subset=self.label_names)
|
||||
# The test data's labels are unknown. They cannot be seen when preprocessing
|
||||
TimeInspector.log_cost_time("Finished dropping null labels.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_drop_null_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Drop data which contain null features if needed.
|
||||
"""
|
||||
# TODO - `Pandas.dropna` is a low performance method.
|
||||
TimeInspector.set_time_mark()
|
||||
df_train = df_train.dropna(subset=self.feature_names)
|
||||
df_valid = df_valid.dropna(subset=self.feature_names)
|
||||
df_test = df_test.dropna(subset=self.feature_names)
|
||||
TimeInspector.log_cost_time("Finished dropping nan.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_replace_inf_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
replace the 'inf' in feature with the mean of this dimension.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
def replace_inf(data):
|
||||
def process_inf(df):
|
||||
for col in df.columns:
|
||||
df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean())
|
||||
return df
|
||||
|
||||
data = data.groupby("datetime").apply(process_inf)
|
||||
data.sort_index(inplace=True)
|
||||
return data
|
||||
|
||||
df_train = replace_inf(df_train)
|
||||
df_valid = replace_inf(df_valid)
|
||||
df_test = replace_inf(df_test)
|
||||
TimeInspector.log_cost_time("Finished replace inf.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
def _process_normalize_feature(self, df_train, df_valid, df_test):
|
||||
"""
|
||||
Normalize data if needed. Two methods are provided: min-max normalization and standard normalization.
|
||||
"""
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
if self.normalize_method == self.MINMAX_NORM:
|
||||
min_train = np.nanmin(df_train[self.feature_names].values, axis=0)
|
||||
max_train = np.nanmax(df_train[self.feature_names].values, axis=0)
|
||||
ignore = min_train == max_train
|
||||
|
||||
def normalize(x, min_train=min_train, max_train=max_train, ignore=ignore):
|
||||
if (~ignore).all():
|
||||
return (x - min_train) / (max_train - min_train)
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[:, i] = (x[:, i] - min_train[i]) / (max_train[i] - min_train[i])
|
||||
return x
|
||||
|
||||
elif self.normalize_method == self.STD_NORM:
|
||||
mean_train = np.nanmean(df_train[self.feature_names].values, axis=0)
|
||||
std_train = np.nanstd(df_train[self.feature_names].values, axis=0)
|
||||
ignore = std_train == 0
|
||||
|
||||
def normalize(x, mean_train=mean_train, std_train=std_train, ignore=ignore):
|
||||
if (~ignore).all():
|
||||
return (x - mean_train) / std_train
|
||||
for i in range(ignore.size):
|
||||
if not ignore[i]:
|
||||
x[:, i] = (x[:, i] - mean_train[i]) / std_train[i]
|
||||
return x
|
||||
|
||||
else:
|
||||
raise ValueError("Normalize method {} is not allowed".format(self.normalize_method))
|
||||
|
||||
df_train.loc(axis=1)[self.feature_names] = normalize(df_train[self.feature_names].values)
|
||||
df_valid.loc(axis=1)[self.feature_names] = normalize(df_valid[self.feature_names].values)
|
||||
df_test.loc(axis=1)[self.feature_names] = normalize(df_test[self.feature_names].values)
|
||||
|
||||
TimeInspector.log_cost_time("Finished normalizing data.")
|
||||
|
||||
return df_train, df_valid, df_test
|
||||
|
||||
|
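The intent of the min-max branch of `_process_normalize_feature` is to fit the range on the training set only and to leave constant columns untouched. A minimal pure-Python sketch of that idea follows; `fit_min_max` and `apply_min_max` are hypothetical names, rows are plain lists, and NaN handling (which the source gets from `np.nanmin`/`np.nanmax`) is omitted.

```python
def fit_min_max(train_rows):
    """Return per-column (mins, maxs) computed on the training rows only."""
    columns = list(zip(*train_rows))
    return [min(col) for col in columns], [max(col) for col in columns]


def apply_min_max(rows, mins, maxs):
    """Scale each column by the training range; constant columns are left unchanged."""
    scaled = []
    for row in rows:
        scaled.append(
            [
                value if hi == lo else (value - lo) / (hi - lo)
                for value, lo, hi in zip(row, mins, maxs)
            ]
        )
    return scaled


train = [[1.0, 5.0], [3.0, 5.0]]  # second column is constant in the training set
test = [[2.0, 7.0]]
mins, maxs = fit_min_max(train)
print(apply_min_max(test, mins, maxs))
```

Reusing the training statistics for the validation and test sets avoids leaking information from later periods into the scaling.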
||||
class ConfigSectionProcessor(Processor):
|
||||
def __init__(self, feature_names, label_names, **kwargs):
|
||||
super().__init__(feature_names, label_names)
|
||||
# Options
|
||||
self.fillna_feature = kwargs.get("fillna_feature", True)
|
||||
self.fillna_label = kwargs.get("fillna_label", True)
|
||||
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
|
||||
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
|
||||
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
|
||||
|
||||
def __call__(self, *args):
|
||||
return [self._transform(x) for x in args]
|
||||
|
||||
def _transform(self, df):
|
||||
def _label_norm(x):
|
||||
x = x - x.mean() # copy
|
||||
x /= x.std()
|
||||
if self.clip_label_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.fillna_label:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
def _feature_norm(x):
|
||||
x = x - x.median() # copy
|
||||
x /= x.abs().median() * 1.4826
|
||||
if self.clip_feature_outlier:
|
||||
x.clip(-3, 3, inplace=True)
|
||||
if self.shrink_feature_outlier:
|
||||
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
|
||||
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
|
||||
if self.fillna_feature:
|
||||
x.fillna(0, inplace=True)
|
||||
return x
|
||||
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
# Copy
|
||||
df_new = df.copy()
|
||||
|
||||
# Label
|
||||
cols = df.columns[df.columns.str.contains("^LABEL")]
|
||||
df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm)
|
||||
|
||||
# Features
|
||||
cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
"KSFT",
|
||||
"OPEN",
|
||||
"HIGH",
|
||||
"LOW",
|
||||
"CLOSE",
|
||||
"VWAP",
|
||||
"ROC",
|
||||
"MA",
|
||||
"BETA",
|
||||
"RESI",
|
||||
"QTLU",
|
||||
"QTLD",
|
||||
"RSV",
|
||||
"SUMP",
|
||||
"SUMN",
|
||||
"SUMD",
|
||||
"VSUMP",
|
||||
"VSUMN",
|
||||
"VSUMD",
|
||||
]
|
||||
pat = "|".join(["^" + x for x in _cols])
|
||||
cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))]
|
||||
df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
|
||||
df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^RSQR")]
|
||||
df_new[cols] = df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")]
|
||||
df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^MIN|^LOW0")]
|
||||
df_new[cols] = df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^CORR|^CORD")]
|
||||
df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df.columns[df.columns.str.contains("^WVMA")]
|
||||
df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
TimeInspector.log_cost_time("Finished preprocessing data.")
|
||||
|
||||
return df_new
|
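`_feature_norm` in `ConfigSectionProcessor` centers each feature on its median and scales by 1.4826 times the median absolute deviation. The sketch below shows that robust z-score on a plain list, using the hard-clipping variant (`clip_feature_outlier`) rather than the default shrinking; `robust_zscore` is a hypothetical name, and the case of a zero median absolute deviation is not handled.

```python
from statistics import median


def robust_zscore(values, clip=3.0):
    center = median(values)
    centered = [v - center for v in values]
    # 1.4826 * MAD approximates the standard deviation for normally
    # distributed data.
    scale = median(abs(v) for v in centered) * 1.4826
    # Clip the scaled values to [-clip, clip].
    return [max(-clip, min(clip, v / scale)) for v in centered]


scores = robust_zscore([1.0, 2.0, 3.0, 4.0, 100.0])
print(scores)
```

Because the median and MAD ignore the magnitude of the outlier `100.0`, the remaining values keep a sensible scale while the outlier is capped at the clip bound.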
||||
@@ -1,317 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from scipy.stats import pearsonr
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from .handler import BaseDataHandler
|
||||
from .launcher import CONFIG_MANAGER
|
||||
from .fetcher import create_fetcher_with_config
|
||||
from ...utils import drop_nan_by_y_index, transform_end_date
|
||||
|
||||
|
||||
class BaseTrainer(object):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler: BaseDataHandler, sacred_ex, **kwargs):
|
||||
# 1. Model.
|
||||
self.model_class = model_class
|
||||
self.model_save_path = model_save_path
|
||||
self.model_args = model_args
|
||||
|
||||
# 2. Data handler.
|
||||
self.data_handler = data_handler
|
||||
|
||||
# 3. Sacred ex.
|
||||
self.ex = sacred_ex
|
||||
|
||||
# 4. Logger.
|
||||
self.logger = get_module_logger("Trainer")
|
||||
|
||||
# 5. Data time
|
||||
self.train_start_date = kwargs.get("train_start_date", None)
|
||||
self.train_end_date = kwargs.get("train_end_date", None)
|
||||
self.validate_start_date = kwargs.get("validate_start_date", None)
|
||||
self.validate_end_date = kwargs.get("validate_end_date", None)
|
||||
self.test_start_date = kwargs.get("test_start_date", None)
|
||||
self.test_end_date = transform_end_date(kwargs.get("test_end_date", None))
|
||||
|
||||
@abstractmethod
|
||||
def train(self):
|
||||
"""
|
||||
Implement this method indicating how to train a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
"""
|
||||
Implement this method indicating how to restore a model and the data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_test_pred(self):
|
||||
"""
|
||||
Implement this method indicating how to get prediction result(s) from a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_test_performance(self):
|
||||
"""
|
||||
Implement this method indicating how to get the performance of the model.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement `get_test_performance`")
|
||||
|
||||
def get_test_score(self):
|
||||
"""
|
||||
Override this method to transform the prediction result(s) into the score of the stock.
|
||||
Note: for multi-label training, you need to combine the predicted labels into one score.
|
||||
Or you can just use the result of `get_test_pred()` (you can also process the result) if this is one label training.
|
||||
We use the first column of the result of `get_test_pred()` as default method (regard it as one label training).
|
||||
"""
|
||||
pred = self.get_test_pred()
|
||||
pred_score = pd.DataFrame(index=pred.index)
|
||||
pred_score["score"] = pred.iloc(axis=1)[0]
|
||||
return pred_score
|
||||
|
||||
|
||||
class StaticTrainer(BaseTrainer):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
|
||||
super(StaticTrainer, self).__init__(model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs)
|
||||
self.model = None
|
||||
|
||||
split_data = self.data_handler.get_split_data(
|
||||
self.train_start_date,
|
||||
self.train_end_date,
|
||||
self.validate_start_date,
|
||||
self.validate_end_date,
|
||||
self.test_start_date,
|
||||
self.test_end_date,
|
||||
)
|
||||
(
|
||||
self.x_train,
|
||||
self.y_train,
|
||||
self.x_validate,
|
||||
self.y_validate,
|
||||
self.x_test,
|
||||
self.y_test,
|
||||
) = split_data
|
||||
|
||||
def train(self):
|
||||
TimeInspector.set_time_mark()
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
if CONFIG_MANAGER.ex_config.finetune:
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
|
||||
if isinstance(loader_model, list):
|
||||
model_index = (
|
||||
-1
|
||||
if CONFIG_MANAGER.ex_config.loader_model_index is None
|
||||
else CONFIG_MANAGER.ex_config.loader_model_index
|
||||
)
|
||||
loader_model = loader_model[model_index]
|
||||
|
||||
model.load(loader_model)
|
||||
model.finetune(self.x_train, self.y_train, self.x_validate, self.y_validate)
|
||||
else:
|
||||
model.fit(self.x_train, self.y_train, self.x_validate, self.y_validate)
|
||||
model.save(self.model_save_path)
|
||||
self.ex.add_artifact(self.model_save_path)
|
||||
self.model = model
|
||||
TimeInspector.log_cost_time("Finished training model.")
|
||||
|
||||
def load(self):
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
# Load model
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
|
||||
if isinstance(loader_model, list):
|
||||
model_index = (
|
||||
-1
|
||||
if CONFIG_MANAGER.ex_config.loader_model_index is None
|
||||
else CONFIG_MANAGER.ex_config.loader_model_index
|
||||
)
|
||||
loader_model = loader_model[model_index]
|
||||
|
||||
model.load(loader_model)
|
||||
|
||||
# Save the model after loading; otherwise this experiment's results will contain no model
|
||||
model.save(self.model_save_path)
|
||||
self.ex.add_artifact(self.model_save_path)
|
||||
self.model = model
|
||||
|
||||
def get_test_pred(self):
|
||||
pred = self.model.predict(self.x_test)
|
||||
pred = pd.DataFrame(pred, index=self.x_test.index, columns=self.y_test.columns)
|
||||
return pred
|
||||
|
||||
def get_test_performance(self):
|
||||
try:
|
||||
model_score = self.model.score(self.x_test, self.y_test)
|
||||
except NotImplementedError:
|
||||
model_score = None
|
||||
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
x_test, y_test, __ = drop_nan_by_y_index(self.x_test, self.y_test)
|
||||
pred_test = self.model.predict(x_test)
|
||||
model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
|
||||
|
||||
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
|
||||
return performance
|
||||
|
||||
|
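`StaticTrainer.get_test_performance` reports the Pearson correlation between predictions and labels after removing rows whose label is NaN. A minimal pure-Python sketch of that metric follows; `pearson_dropna` is a hypothetical name, and unlike the source (which drops rows before predicting and calls `scipy.stats.pearsonr`) it filters the prediction/label pairs afterwards and computes the coefficient by hand.

```python
import math


def pearson_dropna(pred, label):
    # Keep only the pairs whose label is known.
    pairs = [(p, y) for p, y in zip(pred, label) if not math.isnan(y)]
    n = len(pairs)
    mean_p = sum(p for p, _ in pairs) / n
    mean_y = sum(y for _, y in pairs) / n
    cov = sum((p - mean_p) * (y - mean_y) for p, y in pairs)
    var_p = sum((p - mean_p) ** 2 for p, _ in pairs)
    var_y = sum((y - mean_y) ** 2 for _, y in pairs)
    return cov / math.sqrt(var_p * var_y)


r = pearson_dropna([0.1, 0.2, 0.3, 0.9], [1.0, 2.0, 3.0, float("nan")])
print(r)
```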
||||
class RollingTrainer(BaseTrainer):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
|
||||
super(RollingTrainer, self).__init__(
|
||||
model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs
|
||||
)
|
||||
self.rolling_period = kwargs.get("rolling_period", 60)
|
||||
self.models = []
|
||||
self.rolling_data = []
|
||||
self.all_x_test = []
|
||||
self.all_y_test = []
|
||||
for data in self.data_handler.get_rolling_data(
|
||||
self.train_start_date,
|
||||
self.train_end_date,
|
||||
self.validate_start_date,
|
||||
self.validate_end_date,
|
||||
self.test_start_date,
|
||||
self.test_end_date,
|
||||
self.rolling_period,
|
||||
):
|
||||
self.rolling_data.append(data)
|
||||
__, __, __, __, x_test, y_test = data
|
||||
self.all_x_test.append(x_test)
|
||||
self.all_y_test.append(y_test)
|
||||
|
||||
def train(self):
|
||||
# 1. Get total data parts.
|
||||
# total_data_parts = self.data_handler.total_data_parts
|
||||
# self.logger.warning('Total numbers of model are: {}, start training models...'.format(total_data_parts))
|
||||
if CONFIG_MANAGER.ex_config.finetune:
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
loader_model_index = CONFIG_MANAGER.ex_config.loader_model_index
|
||||
previous_model_path = ""
|
||||
# 2. Rolling train.
|
||||
for (
|
||||
index,
|
||||
(x_train, y_train, x_validate, y_validate, x_test, y_test),
|
||||
) in enumerate(self.rolling_data):
|
||||
TimeInspector.set_time_mark()
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
if CONFIG_MANAGER.ex_config.finetune:
|
||||
# Finetune model
|
||||
if loader_model_index is None and isinstance(loader_model, list):
|
||||
try:
|
||||
model.load(loader_model[index])
|
||||
except IndexError:
|
||||
# Load model by previous_model_path
|
||||
with open(previous_model_path, "rb") as fp:
|
||||
model.load(fp)
|
||||
model.finetune(x_train, y_train, x_validate, y_validate)
|
||||
else:
|
||||
|
||||
if index == 0:
|
||||
loader_model = (
|
||||
loader_model[loader_model_index] if isinstance(loader_model, list) else loader_model
|
||||
)
|
||||
model.load(loader_model)
|
||||
else:
|
||||
with open(previous_model_path, "rb") as fp:
|
||||
model.load(fp)
|
||||
|
||||
model.finetune(x_train, y_train, x_validate, y_validate)
|
||||
|
||||
else:
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
|
||||
model_save_path = "{}_{}".format(self.model_save_path, index)
|
||||
model.save(model_save_path)
|
||||
previous_model_path = model_save_path
|
||||
self.ex.add_artifact(model_save_path)
|
||||
self.models.append(model)
|
||||
TimeInspector.log_cost_time("Finished training model: {}.".format(index + 1))
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
Load the data and the model
|
||||
"""
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
for index in range(len(self.all_x_test)):
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
model.load(loader_model[index])
|
||||
|
||||
# Save model
|
||||
model_save_path = "{}_{}".format(self.model_save_path, index)
|
||||
model.save(model_save_path)
|
||||
self.ex.add_artifact(model_save_path)
|
||||
|
||||
self.models.append(model)
|
||||
|
||||
def get_test_pred(self):
|
||||
"""
|
||||
Predict the score on test data with the models.
|
||||
Please ensure the models and data are loaded before calling this method.
|
||||
|
||||
:return: the predicted scores on the test data
|
||||
"""
|
||||
pred_df_list = []
|
||||
y_test_columns = self.all_y_test[0].columns
|
||||
# Start iteration.
|
||||
for model, x_test in zip(self.models, self.all_x_test):
|
||||
pred = model.predict(x_test)
|
||||
pred_df = pd.DataFrame(pred, index=x_test.index, columns=y_test_columns)
|
||||
pred_df_list.append(pred_df)
|
||||
return pd.concat(pred_df_list)
|
||||
|
||||
def get_test_performance(self):
|
||||
"""
|
||||
Get the performances of the models
|
||||
|
||||
:return: a dict with the model score and the Pearson correlation of the models
|
||||
"""
|
||||
pred_test_list = []
|
||||
y_test_list = []
|
||||
scorer = self.models[0]._scorer
|
||||
for model, x_test, y_test in zip(self.models, self.all_x_test, self.all_y_test):
|
||||
# Remove rows from x, y and w that contain NaN in any column of y_test.
|
||||
x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
|
||||
pred_test_list.append(model.predict(x_test))
|
||||
y_test_list.append(np.squeeze(y_test.values))
|
||||
|
||||
pred_test_array = np.concatenate(pred_test_list, axis=0)
|
||||
y_test_array = np.concatenate(y_test_list, axis=0)
|
||||
|
||||
model_score = scorer(y_test_array, pred_test_array)
|
||||
model_pearsonr = pearsonr(np.ravel(y_test_array), np.ravel(pred_test_array))[0]
|
||||
|
||||
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
|
||||
return performance
|
||||