1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

Merge pull request #56 from you-n-g/main

Refactor Qlib interface
This commit is contained in:
you-n-g
2020-11-29 21:23:31 +08:00
committed by GitHub
168 changed files with 13244 additions and 4725 deletions

View File

@@ -43,16 +43,17 @@ jobs:
- name: Lint with Black
run: |
cd ..
python -m black qlib -l 120
python -m black qlib -l 120 --check --diff
- name: Unit tests with Pytest
run: |
cd tests
pytest . --durations=0
- name: Test data downloads and examples
- name: Test data downloads
run: |
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
# cd examples
# estimator -c estimator/estimator_config.yaml
# jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html
- name: Test workflow by config
run: |
qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml

112
README.md
View File

@@ -9,7 +9,7 @@
<p align="center">
<img src="http://fintech.msra.cn/images/logo/1.png" />
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
</p>
@@ -28,6 +28,8 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
- [Quant Model Zoo](#quant-model-zoo)
- [Run a single model](#run-a-single-model)
- [Run multiple models](#run-multiple-models)
- [Quant Dataset Zoo](#quant-dataset-zoo)
- [More About Qlib](#more-about-qlib)
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
@@ -39,19 +41,17 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
# Framework of Qlib
<div style="align: center">
<img src="http://fintech.msra.cn/images/framework.png" />
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
</div>
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
| Name | Description |
| ------ | ----- |
| `Data layer` | `DataServer` focuses on providing high-performance infrastructure for users to manage and retrieve raw data. `DataEnhancement` will preprocess the data and provide the best dataset to be fed into the models. |
| `Interday Model` | `Interday model` focuses on producing prediction scores (aka. _alpha_). Models are trained by `Model Creator` and managed by `Model Manager`. Users could choose one or multiple models for prediction. Multiple models could be combined with `Ensemble` module. |
| `Interday Strategy` | `Portfolio Generator` will take prediction scores as input and output the orders based on the current position to achieve the target portfolio. |
| `Intraday Trading` | `Order Executor` is responsible for executing orders output by `Interday Strategy` and returning the executed results. |
| `Analysis` | Users could get a detailed analysis report of forecasting signals and portfolios in this part. |
| Name | Description |
| ------ | ----- |
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
* The modules with hand-drawn style are under development and will be released in the future.
* The modules with dashed borders are highly user-customizable and extendible.
@@ -128,50 +128,53 @@ Users could create the same dataset with it.
-->
## Auto Quant Research Workflow
Qlib provides a tool named `Estimator` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
1. Quant Research Workflow: Run `Estimator` with [estimator_config.yaml](examples/estimator/estimator_config.yaml) as following. (*Please note that this may **not work** under MacOS with Python 3.8 due to the incompatibility of the `sacred` package we use with Python 3.8. We will fix this bug in the future.*)
1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml)) as following.
```bash
cd examples # Avoid running program under the directory contains `qlib`
estimator -c estimator/estimator_config.yaml
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
```
The result of `Estimator` is as follows, please refer to please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
The result of `qrun` is as follows, please refer to please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
```bash
risk
excess_return_without_cost mean 0.000675
std 0.005456
annualized_return 0.170077
information_ratio 1.963824
max_drawdown -0.063646
excess_return_with_cost mean 0.000479
std 0.005453
annualized_return 0.120776
information_ratio 1.395116
max_drawdown -0.071216
'The following are analysis results of the excess return without cost.'
risk
mean 0.000708
std 0.005626
annualized_return 0.178316
information_ratio 1.996555
max_drawdown -0.081806
'The following are analysis results of the excess return with cost.'
risk
mean 0.000512
std 0.005626
annualized_return 0.128982
information_ratio 1.444287
max_drawdown -0.091078
```
Here are detailed documents for [Estimator](https://qlib.readthedocs.io/en/latest/component/estimator.html).
Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
2. Graphical Reports Analysis: Run `examples/estimator/analyze_from_estimator.ipynb` with `jupyter notebook` to get graphical reports
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
- Forecasting signal (model prediction) analysis
- Cumulative Return of groups
![Cumulative Return](http://fintech.msra.cn/images/analysis/analysis_model_cumulative_return.png?v=0.1)
![Cumulative Return](http://fintech.msra.cn/images_v060/analysis/analysis_model_cumulative_return.png?v=0.1)
- Return distribution
![long_short](http://fintech.msra.cn/images/analysis/analysis_model_long_short.png?v=0.1)
![long_short](http://fintech.msra.cn/images_v060/analysis/analysis_model_long_short.png?v=0.1)
- Information Coefficient (IC)
![Information Coefficient](http://fintech.msra.cn/images/analysis/analysis_model_IC.png?v=0.1)
![Monthly IC](http://fintech.msra.cn/images/analysis/analysis_model_monthly_IC.png?v=0.1)
![IC](http://fintech.msra.cn/images/analysis/analysis_model_NDQ.png?v=0.1)
![Information Coefficient](http://fintech.msra.cn/images_v060/analysis/analysis_model_IC.png?v=0.1)
![Monthly IC](http://fintech.msra.cn/images_v060/analysis/analysis_model_monthly_IC.png?v=0.1)
![IC](http://fintech.msra.cn/images_v060/analysis/analysis_model_NDQ.png?v=0.1)
- Auto Correlation of forecasting signal (model prediction)
![Auto Correlation](http://fintech.msra.cn/images/analysis/analysis_model_auto_correlation.png?v=0.1)
![Auto Correlation](http://fintech.msra.cn/images_v060/analysis/analysis_model_auto_correlation.png?v=0.1)
- Portfolio analysis
- Backtest return
![Report](http://fintech.msra.cn/images/analysis/report.png?v=0.1)
![Report](http://fintech.msra.cn/images_v060/analysis/report.png?v=0.1)
<!--
- Score IC
![Score IC](docs/_static/img/score_ic.png)
@@ -184,21 +187,54 @@ Qlib provides a tool named `Estimator` to run the whole workflow automatically (
-->
## Building Customized Quant Research Workflow by Code
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/train_backtest_analyze.ipynb) is a demo for customized Quant research workflow by code
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
# Quant Model Zoo
# [Quant Model Zoo](examples/benchmarks)
Here is a list of models built on `Qlib`.
- [GBDT based on lightgbm](qlib/contrib/model/gbdt.py)
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
- [GBDT based on XGBoost](qlib/contrib/model/xgboost.py)
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
<!-- - [TFT based on tensorflow](examples/benchmarks/TFT/tft.py) -->
Your PR of new Quant models is highly welcomed.
## Run a single model
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
## Run multiple models
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored. (**Note**: the script will erase your previous experiment records created by running itself.)
Here is an example of running all the models for 10 iterations:
```python
python run_all_model.py 10
```
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
# Quant Dataset Zoo
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`.
- [Alpha360](./qlib/contrib/estimator/handler.py)
- [Alpha158](./qlib/contrib/estimator/handler.py)
| Dataset | US Market | China Market |
| -- | -- | -- |
| [Alpha360](./qlib/contrib/data/handler.py) | √ | √ |
| [Alpha158](./qlib/contrib/data/handler.py) | √ | √ |
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
Your PR to build new Quant dataset is highly welcomed.

View File

@@ -62,6 +62,7 @@ It sees the key of the redis lock has existed in your redis db now. You can use
> select 1
> flushdb
If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If so, try using ``flushall`` to clear all the keys.
.. note::

Binary file not shown.

Before

Width:  |  Height:  |  Size: 40 KiB

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 52 KiB

After

Width:  |  Height:  |  Size: 47 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 66 KiB

After

Width:  |  Height:  |  Size: 63 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 17 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 163 KiB

After

Width:  |  Height:  |  Size: 160 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 53 KiB

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 56 KiB

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 57 KiB

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 47 KiB

After

Width:  |  Height:  |  Size: 47 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 105 KiB

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 205 KiB

After

Width:  |  Height:  |  Size: 271 KiB

View File

@@ -1,4 +1,5 @@
.. _alpha:
===========================
Building Formulaic Alphas
===========================
@@ -49,7 +50,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
.. code-block:: python
>> from qlib.contrib.estimator.handler import QLibDataHandler
>> from qlib.data.dataset.handler import QLibDataHandler
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
>> fields = [MACD_EXP] # MACD
>> names = ['MACD']

View File

@@ -1,4 +1,5 @@
.. _server:
=================================
``Online`` & ``Offline`` mode
=================================

View File

@@ -1,4 +1,5 @@
.. _backtest:
============================================
Intraday Trading: Model&Strategy Testing
============================================

View File

@@ -1,6 +1,7 @@
.. _data:
================================
Data Layer: Data Framework&Usage
Data Layer: Data Framework & Usage
================================
Introduction
@@ -14,7 +15,9 @@ The introduction of ``Data Layer`` includes the following parts.
- Data Preparation
- Data API
- Data Loader
- Data Handler
- Dataset
- Cache
- Data and Cache File Structure
@@ -30,13 +33,19 @@ Such data will be stored with filename suffix `.bin` (We'll call them `.bin` fil
Qlib Format Dataset
--------------------
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the dataset as follows.
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
After running the above command, users can find china-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory.
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
After running the above command, users can find china-stock and us-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
@@ -48,12 +57,45 @@ Converting CSV Format into Qlib Format
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert data in CSV format into `.bin` files (Qlib format).
Users can download the china-stock data in CSV format as follows for reference to the CSV format.
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
.. code-block:: bash
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
- Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
- CSV file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
.. code-block:: bash
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
where the data are in the following format:
.. code-block::
symbol,close
SH600000,120
- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
.. code-block:: bash
python scripts/dump_bin.py dump_all ... --date_field_name date
where the data are in the following format:
.. code-block::
symbol,date,close,open,volume
SH600000,2020-11-01,120,121,12300000
SH600000,2020-11-02,123,120,12300000
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
@@ -61,6 +103,12 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
For other supported parameters when dumping the data into `.bin` file, users can refer to the information by running the following commands:
.. code-block:: bash
python dump_bin.py dump_all --help
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
.. note::
@@ -96,9 +144,8 @@ China-Stock Mode & US-Stock Mode
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` does not provide a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
- Prepare data in CSV format
- Convert data from CSV format to Qlib format, please refer to section `Converting CSV Format into Qlib Format <#converting-csv-format-into-qlib-format>`_.
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in US-stock mode
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
@@ -140,72 +187,97 @@ Filter
Expression dynamic instrument filter. Filter the instruments based on a certain expression. An expression rule indicating a certain feature field is required.
- `basic features filter`: rule_expression = '$close/$open>5'
- `cross-sectional features filter` : rule_expression = '$rank($close)<10'
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
Reference
-------------
To know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_.
Data Loader
=================
``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It will be loaded and used in the ``Data Handler`` module.
QlibDataLoader
---------------
The ``QlibDataLoader`` class in ``Qlib`` is such an interface that allows users to load raw data from the data source.
Interface
------------
Here are some interfaces of the ``QlibDataLoader`` class:
.. autoclass:: qlib.data.dataset.loader.QlibDataLoader
:members: load, load_group_df
API
-----------
To know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_.
Data Handler
=================
Users can use ``Data Handler`` in an automatic workflow by ``Estimator``, refer to `Estimator: Workflow Management <estimator.html>`_ for more details.
The ``Data Handler`` module in ``Qlib`` is designed to handler those common data processing methods which will be used by most of the models.
Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.contrib.estimator.handler.BaseDataHandler``, which provides some interfaces as follows.
Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management <workflow.html>`_ for more details.
Base Class & Interface
DataHandlerLP
--------------
In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some leanable ``Processors`` which can learn the parameters of data processing. When new data comes in, these `trained` ``Processors`` can then infer on the new data and thus processing real-time data in an efficient way. More information about ``Processors`` will be listed in the next subsection.
Interface
----------------------
Qlib provides a base class `qlib.contrib.estimator.BaseDataHandler <../reference/api.html#qlib.contrib.estimator.handler.BaseDataHandler>`_, which provides the following interfaces:
Here are some important interfaces that ``DataHandlerLP`` provides:
- `setup_feature`
Implement the interface to load the data features.
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
:members: __init__, fetch, get_cols
- `setup_label`
Implement the interface to load the data labels and calculate the users' labels.
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
- `setup_processed_data`
Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on.
Qlib also provides two functions to help users init the data handler, users can override them for users' needs.
- `_init_kwargs`
Users can init the kwargs of the data handler in this function, some kwargs may be used when init the raw df.
Kwargs are the other attributes in data.args, like dropna_label, dropna_feature
- `_init_raw_df`
Users can init the raw df, feature names, and label names of data handler in this function.
If the index of feature df and label df are not the same, users need to override this method to merge them (e.g. inner, left, right merge).
If users want to load features and labels by config, users can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
Usage
--------------
Processor
----------
``Data Handler`` can be used as a single module, which provides the following mehtods:
The ``Processor`` module in ``Qlib`` is designed to be learnable and it is responsible for handling data processing such as `normalization` and `drop none/nan features/labels`.
- `get_split_data`
- According to the start and end dates, return features and labels of the pandas DataFrame type used for the 'Model'
- `get_rolling_data`
- According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling.
``Qlib`` provides the following ``Processors``:
- ``DropnaProcessor``: `processor` that drops N/A features.
- ``DropnaLabel``: `processor` that drops N/A labels.
- ``TanhProcess``: `processor` that uses `tanh` to process noise data.
- ``ProcessInf``: `processor` that handles infinity values, it will be replaces by the mean of the column.
- ``Fillna``: `processor` that handles N/A values, which will fill the N/A value by 0 or other given number.
- ``MinMaxNorm``: `processor` that applies min-max normalization.
- ``ZscoreNorm``: `processor` that applies z-score normalization.
- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
To know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_.
Example
--------------
``Data Handler`` can be run with ``estimator`` by modifying the configuration file, and can also be used as a single module.
``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.
Know more about how to run ``Data Handler`` with ``Estimator``, please refer to `Estimator: Workflow Management <estimator.html>`_
Know more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_
Qlib provides implemented data handler `Alpha158`. The following example shows how to run `Alpha158` as a single module.
@@ -214,44 +286,54 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
.. code-block:: Python
from qlib.contrib.estimator.handler import Alpha158
from qlib.contrib.model.gbdt import LGBModel
import qlib
from qlib.contrib.data.handler import Alpha158
DATA_HANDLER_CONFIG = {
"dropna_label": True,
"start_date": "2007-01-01",
"end_date": "2020-08-01",
"market": "csi300",
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi300",
}
TRAINER_CONFIG = {
"train_start_date": "2007-01-01",
"train_end_date": "2014-12-31",
"validate_start_date": "2015-01-01",
"validate_end_date": "2016-12-31",
"test_start_date": "2017-01-01",
"test_end_date": "2020-08-01",
}
if __name__ == "__main__":
qlib.init()
h = Alpha158(**data_handler_config)
exampleDataHandler = Alpha158(**DATA_HANDLER_CONFIG)
# get all the columns of the data
print(h.get_cols())
# example of 'get_split_data'
x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG)
# fetch all the labels
print(h.fetch(col_set="label"))
# example of 'get_rolling_data'
for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG):
print(x_train, y_train, x_validate, y_validate, x_test, y_test)
.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the `fit`, `predic``, and `score` methods of the ``Interday Model`` , please refer to `Model <model.html#base-class-interface>`_.
Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
# fetch all the features
print(h.fetch(col_set="feature"))
API
---------
To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.contrib.estimator.handler>`_.
To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_.
Dataset
=================
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the rights to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most important interface of the class:
.. autoclass:: qlib.data.dataset.__init__.DatasetH
:members:
API
---------
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
Cache
==========

View File

@@ -1,706 +0,0 @@
.. _estimator:
=================================
Estimator: Workflow Management
=================================
.. currentmodule:: qlib
Introduction
===================
The components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users could build their own Quant research workflow with these components like `Example <https://github.com/microsoft/qlib/blob/main/examples/train_and_backtest.py>`_
Besides, ``Qlib`` provides more user-friendly interfaces named ``Estimator`` to automatically run the whole workflow defined by configuration. A concrete execution of the whole workflow is called an `experiment`.
With ``Estimator``, user can easily run an `experiment`, which includes the following steps:
- Data
- Loading
- Processing
- Slicing
- Model
- Training and inference(static or rolling)
- Saving & loading
- Evaluation(Back-testing)
For each `experiment`, ``Qlib`` will capture the model training details, performance evaluation results and basic information (e.g. names, ids). The captured data will be stored in backend-storage (disk or database).
Complete Example
===================
Before getting into details, here is a complete example of ``Estimator``, which defines the workflow in typical Quant research.
Below is a typical config file of ``Estimator``.
.. code-block:: YAML
experiment:
name: estimator_example
observer_type: file_storage
mode: train
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
args:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
data:
class: Alpha158
args:
dropna_label: True
filter:
market: csi500
trainer:
class: StaticTrainer
args:
rolling_period: 360
train_start_date: 2007-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-12-31
test_start_date: 2017-01-01
test_end_date: 2020-08-01
strategy:
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: SH000905
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
.. code-block:: bash
estimator -c configuration.yaml
.. note:: `estimator` will be placed in your $PATH directory when installing ``Qlib``.
Configuration File
===================
Let's get into details of ``Estimator`` in this section.
Before using ``estimator``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
Experiment Section
--------------------
At first, the configuration file needs to contain a section named `experiment` about the basic information. This section describes how `estimator` tracks and persists current `experiment`. ``Qlib`` used `sacred`, a lightweight open-source tool, to configure, organize, generate logs, and manage experiment results. Partial behaviors of `sacred` will base on the `experiment` section.
Following files will be saved by `sacred` after `estimator` finish an `experiment`:
- `model.bin`, model binary file
- `pred.pkl`, model prediction result file
- `analysis.pkl`, backtest performance analysis file
- `positions.pkl`, backtest position records file
- `run`, the experiment information object, usually contains some meta information such as the experiment name, experiment date, etc.
Here is the typical configuration of `experiment section`
.. code-block:: YAML
experiment:
name: test_experiment
observer_type: mongo
mongo_url: mongodb://MONGO_URL
db_name: public
finetune: false
exp_info_path: /home/test_user/exp_info.json
mode: test
loader:
id: 677
The meaning of each field is as follows:
- `name`
The experiment name, str type, `sacred <https://github.com/IDSIA/sacred>_` will use this experiment name as an identifier for some important internal processes. Users can find this field in `run` object of `sacred`. The default value is `test_experiment`.
- `observer_type`
Observer type, str type, there are two choices which include `file_storage` and `mongo` respectively. If `file_storage` is selected, all the above-mentioned managed contents will be stored in the `dir` directory, separated by the number of times of experiments as a subfolder. If it is `mongo`, the content will be stored in the database. The default is `file_storage`.
- For `file_storage` observer.
- `dir`
Directory URL, str type, directory for `file_storage` observer type, files captured and managed by sacred with `file_storage` observer will be saved to this directory, which is the same directory as `config.json` by default.
- For `mongo` observer.
- `mongo_url`
Database URL, str type, required if the observer type is `mongo`.
- `db_name`
Database name, str type, required if the observer type is `mongo`.
- `finetune`
``Estimator``'s behaviors to train models will base on this flag.
If you just want to train models from scratch each time instead of based on existing models, please leave `finetune=false`. Otherwise please read the
details below.
The following table is the processing logic for different situations.
========== =========================================== ==================================== =========================================== ==========================================
. Static Rolling
. finetune:true finetune:false finetune:true finetune:false
========== =========================================== ==================================== =========================================== ==========================================
Train - Need to provide model (Static or Rolling) - No need to provide model - Need to provide model (Static or Rolling) - Need to provide model (Static or Rolling)
- The args in model section will be - The args in model section will be - The args in model section will be - The args in model section will be
used for finetuning used for training used for finetuning used for finetuning
- Update based on the provided model - Train model from scratch - Update based on the provided model - Based on the provided model update
and parameters and parameters - Train model from scratch
- **Each rolling time slice is based on** - **Train each rolling time slice**
**a model updated from the previous** **separately**
**time**
Test - Model must exist, otherwise an exception will be raised.
- For `StaticTrainer`, users need to train a model and record 'exp_info' for 'Test'.
- For `RollingTrainer`, users need to train a set of models until the latest time, and record 'exp_info' for 'Test'.
========== =============================================================================================================================================================================
.. note::
1. finetune parameters: share model.args parameters.
2. provide model: from `loader.model_index`, load the index of the model(starting from 0).
3. If `loader.model_index` is None:
- In 'Static Finetune=True', if provide 'Rolling', use the last model to update.
- For `RollingTrainer` with Finetune=True.
- If `StaticTrainer` is used in loader, the model will be used for initialization for finetuning.
- If `RollingTrainer` is used in loader, the existing models will be used without any modification and the new models will be initialized with the model in the last period and finetune one by one.
- `exp_info_path`
save path of experiment info, str type, save the experiment info and model `prediction score` after the experiment is finished. Optional parameter, the default value is `<config_file_dir>/ex_name/exp_info.json`.
- `mode`
`train` or `test`, str type.
- `test mode` is designed for inference. Under `test mode`, it will load the model according to the parameters of `loader` and skip model training.
- `train model` is the default value. It will train new models by default and
Please note that when it fails to load model, it will fall back to `fit` model.
.. note::
if users choose ` test mode`, they need to make sure:
- The loader of `test_start_date` must be less than or equal to the current `test_start_date`.
- If other parameters of the `loader` model args are different, a warning will appear.
- `loader`
If you just want to train models from scratch each time instead of based on existing models, please ignore `loader` section. Otherwise please read the
details below.
The `loader` section only works when the `mode` is `test` or `finetune` is `true`.
- `model_index`
Model index, int type. The index of the loaded model in loader_models (starting at 0) for the first `finetune`. The default value is None.
- `exp_info_path`
Loader model experiment info path, str type. If the field exists, the following parameters will be parsed from `exp_info_path`, and the following parameters will not work. One of this field and `id` must exist at least .
- `id`
The experiment id of the model that needs to be loaded, int type. If the `mode` is `test`, this value is required. This field and `exp_info_path` must exist one.
- `name`
The experiment name of the model that needs to be loaded, str type. The default value is the current experiment `name`.
- `observer_type`
The experiment observer type of the model that needs to be loaded, str type. The default value is the current experiment `observer_type`.
.. note:: The observer type is a concept of the `sacred` module, which determines how files, standard input, and output which are managed by sacred are stored.
- `file_storage`
If `observer_type` is `file_storage`, the config may be as follows.
.. code-block:: YAML
experiment:
name: test_experiment
dir: <path to a directory> # default is dir of `config.yml`
observer_type: file_storage
- `mongo`
If `observer_type` is `mongo`, the config may be as follows.
.. code-block:: YAML
experiment:
name: test_experiment
observer_type: mongo
mongo_url: mongodb://MONGO_URL
db_name: public
Users need to indicate `mongo_url` and `db_name` for a mongo observer.
.. note::
If users choose the mongo observer, they need to make sure:
- Have an environment with the mongodb installed and a mongo database dedicated to storing the results of the experiments.
- The python environment (the version of python and package) to run the experiments and the one to fetch the results are consistent.
Model Section
-----------------
Users can use a specified model by configuration with hyper-parameters.
Custom Models
~~~~~~~~~~~~~~~~~
Qlib supports custom models, but it must be a subclass of the `qlib.contrib.model.Model`, the config for a custom model may be as following.
.. code-block:: YAML
model:
class: SomeModel
module_path: /tmp/my_experment/custom_model.py
args:
loss: binary
The class `SomeModel` should be in the module `custom_model`, and ``Qlib`` could parse the `module_path` to load the class.
To know more about ``Interday Model``, please refer to `Interday Model: Training & Prediction <model.html>`_.
Data Section
-----------------
``Data Handler`` can be used to load raw data, prepare features and label columns, preprocess data (standardization, remove NaN, etc.), split training, validation, and test sets. It is a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`.
Users can use the specified data handler by config as follows.
.. code-block:: YAML
data:
class: Alpha158
args:
start_date: 2005-01-01
end_date: 2018-04-30
dropna_label: True
filter:
market: csi500
filter_pipeline:
-
class: NameDFilter
module_path: qlib.filter
args:
name_rule_re: S(?!Z3)
fstart_time: 2018-01-01
fend_time: 2018-12-11
-
class: ExpressionDFilter
module_path: qlib.filter
args:
rule_expression: $open/$factor<=45
fstart_time: 2018-01-01
fend_time: 2018-12-11
- `class`
Data handler class, str type, which should be a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`, and implements 5 important interfaces for loading features, loading raw data, preprocessing raw data, slicing train, validation, and test data. The default value is `ALPHA360`. If users want to write a data handler to retrieve the data in ``Qlib``, `QlibDataHandler` is suggested.
- `module_path`
The module path, str type, absolute url is also supported, indicates the path of the `class` implementation of the data processor class. The default value is `qlib.contrib.estimator.handler`.
- `args`
Parameters used for ``Data Handler`` initialization.
- `train_start_date`
Training start time, str type, the default value is `2005-01-01`.
- `start_date`
Data start date, str type.
- `end_date`
Data end date, str type. the data from start_date to end_date decides which part of data will be loaded in `datahandler`, users can only use these data in the following parts.
- `dropna_feature` (Optional in args)
Drop Nan feature, bool type, the default value is False.
- `dropna_label` (Optional in args)
Drop Nan label, bool type, the default value is True. Some multi-label tasks will use this.
- `normalize_method` (Optional in args)
Normalize data by a given method. str type. ``Qlib`` gives two normalizing methods, `MinMax` and `Std`.
If users want to build their own method, please override `_process_normalize_feature`.
- `filter`
Dynamically filtering the stocks based on the filter pipeline.
- `market`
index name, str type, the default value is `csi500`.
- `filter_pipeline`
Filter rule list, list type, the default value is []. Can be customized according to users' needs.
- `class`
Filter class name, str type.
- `module_path`
The module path, str type.
- `args`
The filter class parameters, these parameters are set according to the `class`, and all the parameters as kwargs to `class`.
Custom Data Handler
~~~~~~~~~~~~~~~~~~~~~~
Qlib support custom data handler, but it must be a subclass of the ``qlib.contrib.estimator.handler.BaseDataHandler``, the config for custom data handler may be as follows.
.. code-block:: YAML
data:
class: SomeDataHandler
module_path: /tmp/my_experment/custom_data_handler.py
args:
start_date: 2005-01-01
end_date: 2018-04-30
The class `SomeDataHandler` should be in the module `custom_data_handler`, and ``Qlib`` could parse the `module_path` to load the class.
If users want to load features and labels by config, they can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also has provided some preprocess methods in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended, from which users can inherit the custom class. `QLibDataHandler` is also a subclass of `ConfigDataHandler`.
To know more about ``Data Handler``, please refer to `Data Framework&Usage <data.html>`_.
Trainer Section
-----------------
Users can specify the trainer ``Trainer`` by the config file, which is a subclass of ``qlib.contrib.estimator.trainer.BaseTrainer`` and implement three important interfaces for training the model, restoring the model, and getting model predictions as follows.
- `train`
Implement this interface to train the model.
- `load`
Implement this interface to recover the model from disk.
- `get_pred`
Implement this interface to get model prediction results.
Qlib have provided two implemented trainer,
- `StaticTrainer`
The static trainer will be trained using the training, validation, and test data of the data processor static slicing.
- `RollingTrainer`
The rolling trainer will use the rolling iterator of the data processor to split data for rolling training.
Users can specify `trainer` with the configuration file:
.. code-block:: YAML
trainer:
class: StaticTrainer # or RollingTrainer
args:
rolling_period: 360
train_start_date: 2005-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-06-30
test_start_date: 2016-07-01
test_end_date: 2017-07-31
- `class`
Trainer class, which should be a subclass of `qlib.contrib.estimator.trainer.BaseTrainer`, and needs to implement three important interfaces, the default value is `StaticTrainer`.
- `module_path`
The module path, str type, absolute url is also supported, indicates the path of the trainer class implementation.
- `args`
Parameters used for ``Trainer`` initialization.
- `rolling_period`
The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. Only used in `RollingTrainer`.
- `train_start_date`
Training start time, str type.
- `train_end_date`
Training end time, str type.
- `validate_start_date`
Validation start time, str type.
- `validate_end_date`
Validation end time, str type.
- `test_start_date`
Test start time, str type.
- `test_end_date`
Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.
Custom Trainer
~~~~~~~~~~~~~~~~~~
Qlib supports custom trainer, but it must be a subclass of the `qlib.contrib.estimator.trainer.BaseTrainer`, the config for a custom trainer may be as following:
.. code-block:: YAML
trainer:
class: SomeTrainer
module_path: /tmp/my_experment/custom_trainer.py
args:
train_start_date: 2005-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-06-30
test_start_date: 2016-07-01
test_end_date: 2017-07-31
The class `SomeTrainer` should be in the module `custom_trainer`, and ``Qlib`` could parse the `module_path` to load the class.
Strategy Section
-----------------
Users can specify strategy through a config file, for example:
.. code-block:: YAML
strategy :
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
- `class`
The strategy class, str type, should be a subclass of `qlib.contrib.strategy.strategy.BaseStrategy`. The default value is `TopkDropoutStrategy`.
- `module_path`
The module location, str type, absolute url is also supported, and absolute path is also supported, indicates the location of the policy class implementation.
- `args`
Parameters used for ``Trainer`` initialization.
- `topk`
The number of stocks in the portfolio
- `n_drop`
Number of stocks to be replaced in each trading date
Custom Strategy
^^^^^^^^^^^^^^^^^^^
Qlib supports custom strategy, but it must be a subclass of the ``qlib.contrib.strategy.strategy.BaseStrategy``, the config for custom strategy may be as following:
.. code-block:: YAML
strategy :
class: SomeStrategy
module_path: /tmp/my_experment/custom_strategy.py
The class `SomeStrategy` should be in the module `custom_strategy`, and ``Qlib`` could parse the `module_path` to load the class.
To know more about ``Strategy``, please refer to `Strategy <strategy.html>`_.
Backtest Section
-----------------
Users can specify `backtest` through a config file, for example:
.. code-block:: YAML
backtest :
normal_backtest_args:
topk: 50
benchmark: SH000905
account: 500000
deal_price: close
min_cost: 5
subscribe_fields:
- $close
- $change
- $factor
- `normal_backtest_args`
Normal backtest parameters. All the parameters in this section will be passed to the ``qlib.contrib.evaluate.backtest`` function in the form of `**kwargs`.
- `benchmark`
Stock index symbol, str, or list type, the default value is `None`.
.. note::
* If `benchmark` is None, it will use the average change of the day of all stocks in 'pred' as the 'bench'.
* If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'.
* If `benchmark` is str, it will use the daily change as the 'bench'.
- `account`
Backtest initial cash, integer type. The `account` in `strategy` section is deprecated. It only works when `account` is not set in `backtest` section. It will be overridden by `account` in the `backtest` section. The default value is 1e9.
- `deal_price`
Order transaction price field, str type, the default value is close.
- `min_cost`
Min transaction cost, float type, the default value is 5.
- `subscribe_fields`
Subscribe quote fields, array type, the default value is [`deal_price`, $close, $change, $factor].
Qlib Data Section
--------------------
The `qlib_data` field describes the parameters of qlib initialization.
.. code-block:: YAML
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"
- `provider_uri`
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
- `region`
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
- If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
- `redis_host`
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
The lock and cache mechanism relies on redis.
- `redis_port`
Type: int, optional parameter(default: 6379), port of `redis`
.. note::
The value of `region` should be aligned with the data stored in `provider_uri`. Currently, ``scripts/get_data.py`` only provides China stock market data. If users want to use the US stock market data, they should prepare their own US-stock data in `provider_uri` and switch to US-stock mode.
.. note::
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <data.html#cache>`_ for details.
Please refer to `Initialization <../start/initialization.html>`_.
Experiment Result
===================
Form of Experimental Result
----------------------------
The result of the experiment is also the result of the ``Intraday Trading(Backtest)``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
Get Experiment Result
----------------------------
Base Class & Interface
~~~~~~~~~~~~~~~~~~~~~~~
Users can check the experiment results from file storage directly, or check the experiment results from the database, or get the experiment results through two interfaces of a base class `Fetcher` provided by ``Qlib``.
The `Fetcher` provides the following interface
- `get_experiments(self, exp_name=None):`
The interface takes one parameters. The `exp_name` is the experiment name, the default is all experiments. Users can get the returned dictionary with a list of ids and test end date as follows.
.. code-block:: JSON
{
"ex_a": [
{
"id": 1,
"test_end_date": "2017-01-01"
}
],
"ex_b": [
...
]
}
- `get_experiment(exp_name, exp_id, fields=None)`
The interface takes three parameters. The first parameter is the experiment name, the second parameter is the experiment id, and the third parameter is a list of fields. The default value of `fields` is None, which means all fields.
.. note::
Currently supported fields:
['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
Users can get the returned dictionary as follows.
.. code-block:: JSON
{
'analysis': analysis_df,
'pred': pred_df,
'positions': positions_dic,
'report_normal': report_normal_df,
}
Implemented `Fetcher` s & Examples
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``Qlib`` provides two implemented `Fetcher` s as follows.
`FileFetcher`
^^^^^^^^^^^^^^^
The `FileFetcher` is a subclass of `Fetcher`, which could fetch files from `file_storage` observer. The following is an example:
.. code-block:: python
>>> from qlib.contrib.estimator.fetcher import FileFetcher
>>> f = FileFetcher(experiments_dir=r'./')
>>> print(f.get_experiments())
{
'test_experiment': [
{
'id': '1',
'config': ...
},
{
'id': '2',
'config': ...
},
{
'id': '3',
'config': ...
}
]
}
>>> print(f.get_experiment('test_experiment', '1'))
risk
excess_return_without_cost mean 0.000605
std 0.005481
annualized_return 0.152373
information_ratio 1.751319
max_drawdown -0.059055
excess_return_with_cost mean 0.000410
std 0.005478
annualized_return 0.103265
information_ratio 1.187411
max_drawdown -0.075024
`MongoFetcher`
^^^^^^^^^^^^^^^
The `FileFetcher` is a subclass of `Fetcher`, which could fetch files from `mongo` observer. Users should initialize the fetcher with `mongo_url`. The following is an example:
.. code-block:: python
>>> from qlib.contrib.estimator.fetcher import MongoFetcher
>>> f = MongoFetcher(mongo_url=..., db_name=...)

View File

@@ -1,4 +1,5 @@
.. _model:
============================================
Interday Model: Model Training & Prediction
============================================
@@ -6,164 +7,138 @@ Interday Model: Model Training & Prediction
Introduction
===================
``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``Estimator``, please refer to `Estimator: Workflow Management <estimator.html>`_.
``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_.
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Interday Model`` can be used as an independent module also.
Base Class & Interface
======================
``Qlib`` provides a base class `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ from which all models should inherit.
``Qlib`` provides a base class `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ from which all models should inherit.
The base class provides the following interfaces:
- `__init__(**kwargs)`
- Initialization.
- If users use ``Estimator`` to start an `experiment`, the parameter of `__init__` method shoule be consistent with the hyperparameters in the configuration file.
- `fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)`
- `fit(self, dataset, **kwargs)`
- Train model.
- Parameter:
- `x_train`, pd.DataFrame type, train feature
The following example explains the value of `x_train`:
- `dataset`, ``Qlib``'s ``DatasetH`` type. For more information about ``DatasetH``, users can refer to the related document: `Qlib Dataset <../component/data.html#dataset>`_.
The `dataset` is passed into the `model`'s method because there are some unique data preprocessing procedures for each, we want to give each model maximum flexibility to handle the data that is suitable for their own.
The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:
.. code-block:: YAML
KMID KLEN KMID2 KUP KUP2
instrument datetime
SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275
2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998
2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666
2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001
2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648
... ... ... ... ... ...
SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052
2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824
2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000
2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433
2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216
.. code-block:: Python
`x_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. Each column of `x_train` corresponds to a feature, and the column name is the feature name.
.. note::
The number and names of the columns are determined by the data handler, please refer to `Data Handler <data.html#data-handler>`_ and `Estimator Data Section <estimator.html#data-section>`_.
- `y_train`, pd.DataFrame type, train label
The following example explains the value of `y_train`:
# get features and labels
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
.. code-block:: YAML
LABEL
instrument datetime
SH600004 2012-01-04 -0.798456
2012-01-05 -1.366716
2012-01-06 -0.491026
2012-01-09 0.296900
2012-01-10 0.501426
... ...
SZ300273 2014-12-25 -0.465540
2014-12-26 0.233864
2014-12-29 0.471368
2014-12-30 0.411914
2014-12-31 1.342723
`y_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. The `LABEL` column represents the value of train label.
.. note::
The number and names of the columns are determined by the ``Data Handler``, please refer to `Data Handler <data.html#data-handler>`_.
- `x_valid`, pd.DataFrame type, validation feature
The format of `x_valid` is same as `x_train`
- `y_valid`, pd.DataFrame type, validation label
The format of `y_valid` is same as `y_train`
- `w_train`(Optional args, default is None), pd.DataFrame type, train weight
`w_train` is a pandas DataFrame, whose shape and index is same as `x_train`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
- `w_train`(Optional args, default is None), pd.DataFrame type, validation weight
`w_train` is a pandas DataFrame, whose shape and index is the same as `x_valid`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
- `predict(self, x_test, **kwargs)`
- Predict test data 'x_test'
- Parameter:
- `x_test`, pd.DataFrame type, test features
The form of `x_test` is same as `x_train` in 'fit' method.
- Return:
- `label`, np.ndarray type, test label
The label of `x_test` that predicted by model.
- `score(self, x_test, y_test, w_test=None, **kwargs)`
- Evaluate model with test feature/label
- Parameter:
- `x_test`, pd.DataFrame type, test feature
The format of `x_test` is same as `x_train` in `fit` method.
# get weights
try:
wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
except KeyError as e:
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
- `x_test`, pd.DataFrame type, test label
The format of `y_test` is same as `y_train` in `fit` method.
- `predict(self, dataset, **kwargs)`
- Predict test data.
- Parameter:
- `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
- Returns:
- Predic results with type: `pandas.Series`.
- `w_test`, pd.DataFrame type, test weight
The format of `w_test` is same as `w_train` in `fit` method.
- Return: float type, evaluation score
- `finetune(self, dataset, **kwargs)`
- Finetune the model.
- Parameter:
- `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
For other interfaces such as `save`, `load`, `finetune`, please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
For other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
Example
==================
``Qlib`` provides ``LightGBM`` and ``DNN`` models as the baseline, the following steps show how to run`` LightGBM`` as an independent module.
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. The following steps show how to run`` LightGBM`` as an independent module.
- Initialize ``Qlib`` with `qlib.init` first, please refer to `Initialization <../start/initialization.html>`_.
- Run the following code to get the `prediction score` `pred_score`
.. code-block:: Python
from qlib.contrib.estimator.handler import Alpha158
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.utils import init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
DATA_HANDLER_CONFIG = {
"dropna_label": True,
"start_date": "2007-01-01",
"end_date": "2020-08-01",
"market": MARKET,
market = "csi300"
benchmark = "SH000300"
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
TRAINER_CONFIG = {
"train_start_date": "2007-01-01",
"train_end_date": "2014-12-31",
"validate_start_date": "2015-01-01",
"validate_end_date": "2016-12-31",
"test_start_date": "2017-01-01",
"test_end_date": "2020-08-01",
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(
**DATA_HANDLER_CONFIG
).get_split_data(**TRAINER_CONFIG)
# start exp
with R.start(experiment_name="workflow"):
# train
R.log_params(**flatten_dict(task))
model.fit(dataset)
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
MODEL_CONFIG = {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
}
# use default model
model = LGBModel(**MODEL_CONFIG)
model.fit(x_train, y_train, x_validate, y_validate)
_pred = model.predict(x_test)
pred_score = pd.DataFrame(index=_pred.index)
pred_score["score"] = _pred.iloc(axis=1)[0]
.. note:: `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
.. note::
`Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
`SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow <recorder.html#record-template>`_.
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
@@ -175,4 +150,4 @@ Qlib supports custom models. If users are interested in customizing their own mo
API
===================
Please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.

View File

@@ -0,0 +1,97 @@
.. _recorder:
====================================
Qlib Recorder: Experiment Management
====================================
.. currentmodule:: qlib
Introduction
===================
``Qlib`` contains an experiment management system named ``QlibRecorder``, which is designed to help users handle experiment and analysis results in an efficient way.
There are three components of the system:
- `ExperimentManager`
a class that manages experiments.
- `Experiment`
a class of experiment, and each instance of it is responsible for a single experiment.
- `Recorder`
a class of recorder, and each instance of it is responsible for a single run.
Here is a general view of the structure of the system:
.. code-block::
ExperimentManager
- Experiment 1
- Recorder 1
- Recorder 2
- ...
- Experiment 2
- Recorder 1
- Recorder 2
- ...
- ...
Currently, the components of this experiment management system are implemented using the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
Qlib Recorder
===================
``QlibRecorder`` provides a high level API for users to use the experiment management system. The interfaces are wrapped in the variable ``R`` in ``Qlib``, and users can directly use ``R`` to interact with the system. The following command shows how to import ``R`` in Python:
.. code-block:: Python
from qlib.workflow import R
``QlibRecorder`` includes several common API for managing `experiments` and `recorders` within a workflow. For more available APIs, please refer to the following section about `Experiment Manager`, `Experiment` and `Recorder`.
Here are the available interfaces of ``QlibRecorder``:
.. autoclass:: qlib.workflow.__init__.QlibRecorder
:members:
Experiment Manager
===================
The ``ExpManager`` module in ``Qlib`` is responsible for managing different experiments. Most of the APIs of ``ExpManager`` are similar to ``QlibRecorder``, and the most important API will be the ``get_exp`` method. User can directly refer to the documents above for some detailed information about how to use the ``get_exp`` method.
.. autoclass:: qlib.workflow.expm.ExpManager
:members: get_exp, list_experiments
For other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.
Experiment
===================
The ``Experiment`` class is solely responsible for a single experiment, and it will handle any operations that are related to an experiment. Basic methods such as `start`, `end` an experiment are included. Besides, methods related to `recorders` are also available: such methods include `get_recorder` and `list_recorders`.
.. autoclass:: qlib.workflow.exp.Experiment
:members: get_recorder, list_recorders
For other interfaces such as `search_records`, `delete_recorder`, please refer to `Experiment API <../reference/api.html#experiment>`_.
Recorder
===================
The ``Recorder`` class is responsible for a single recorder. It will handle some detailed operations such as ``log_metrics``, ``log_params`` of a single run. It is designed to help user to easily track results and things being generated during a run.
Here are some important APIs that are not included in the ``QlibRecorder``:
.. autoclass:: qlib.workflow.recorder.Recorder
:members: list_artifacts, list_metrics, list_params, list_tags
For other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.
Record Template
===================
The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
- ``SignalRecord``: This class generates the `preidction` results of the model.
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.

View File

@@ -1,12 +1,13 @@
.. _report:
==========================================
Aanalysis: Evaluation & Results Analysis
Analysis: Evaluation & Results Analysis
==========================================
Introduction
===================
``Aanalysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
``Analysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
- analysis_position
- report_graph

View File

@@ -1,4 +1,5 @@
.. _strategy:
========================================
Interday Strategy: Portfolio Management
========================================

280
docs/component/workflow.rst Normal file
View File

@@ -0,0 +1,280 @@
.. _workflow:
=================================
Workflow: Workflow Management
=================================
.. currentmodule:: qlib
Introduction
===================
The components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users could build their own Quant research workflow with these components like `Example <https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py>`_.
Besides, ``Qlib`` provides more user-friendly interfaces named ``qrun`` to automatically run the whole workflow defined by configuration. A concrete execution of the whole workflow is called an `experiment`.
With ``qrun``, user can easily run an `experiment`, which includes the following steps:
- Data
- Loading
- Processing
- Slicing
- Model
- Training and inference
- Saving & loading
- Evaluation
- Forecast signal analysis
- Backtest
For each `experiment`, ``Qlib`` has a complete system to tracking all the information as well as artifacts generated during training, inference and evaluation phase. For more information about how Qlib handles `experiment`, please refer to the related document: `Recorder: Experiment Management <../component/recorder.html>`_.
Complete Example
===================
Before getting into details, here is a complete example of ``qrun``, which defines the workflow in typical Quant research.
Below is a typical config file of ``qrun``.
.. code-block:: YAML
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
.. code-block:: bash
qrun -c configuration.yaml
.. note::
`qrun` will be placed in your $PATH directory when installing ``Qlib``.
Configuration File
===================
Let's get into details of ``qrun`` in this section.
Before using ``qrun``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
Qlib Data Section
--------------------
At first, the configuration file needs to contain several basic parameters about the data, which will be used for qlib initialization, data handling and backtest.
.. code-block:: YAML
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
The meaning of each field is as follows:
- `provider_uri`
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
- `region`
- If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
- If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
.. note::
The value of `region` should be aligned with the data stored in `provider_uri`.
- `market`
Type: str. Index name, the default value is `csi500`.
- `benchmark`
Type: str, list or pandas.Series. Stock index symbol, the default value is `SH000905`.
.. note::
* If `benchmark` is str, it will use the daily change as the 'bench'.
* If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'.
* If `benchmark` is pandas.Series, whose `index` is trading date and the value T is the change from T-1 to T, it will be directly used as the 'bench'. An example is as following:
.. code-block:: python
print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
2017-01-04 0.011693
2017-01-05 0.000721
2017-01-06 -0.004322
2017-01-09 0.006874
2017-01-10 -0.003350
.. note::
The symbol `&` in `yaml` file stands for an anchor of a field, which is useful when another fields include this parameter as part of the value. Taking the configuration file above as an example, users can directly change the value of `market` and `benchmark` without traversing the entire configuration file.
Model Section
--------------------
In the `task` field, the `model` section describes the parameters of the model to be used for training and inference. For more information about the base ``Model`` class, please refer to `Qlib Model <../component/model.html>`_.
.. code-block:: YAML
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
The meaning of each field is as follows:
- `class`
Type: str. The name for the model class.
- `module_path`
Type: str. The path for the model in qlib.
- `kwargs`
The keywords arguments for the model. Please refer to the specific model implementation for more information: `models <https://github.com/microsoft/qlib/blob/main/qlib/contrib/model>`_.
.. note::
``Qlib`` provides a util named: ``init_instance_by_config`` to initialize any class inside ``Qlib`` with the configuration includes the fields: `class`, `module_path` and `kwargs`.
Dataset Section
--------------------
The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Model <../component/data.html#dataset>`_.
The keywords arguments configuration of the ``DataHandler`` is as follows:
.. code-block:: YAML
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
Users can refer to the document of `DataHandler <../component/data.html#datahandler>`_ for more information about the meaning of each field in the configuration.
Here is the configuration for the ``Dataset`` module which will take care of data preprossing and slicing during the training and testing phase.
.. code-block:: YAML
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
Record Section
--------------------
The `record` field is about the parameters the ``Record`` module in ``Qlib``. ``Record`` is responsible for generating certain analysis and evaluation results such as `prediction`, `information Coefficient (IC)` and `backtest`.
The following script is the configuration of `backtest` and the `strategy` used in `backtest`:
.. code-block:: YAML
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
For more information about the meaning of each field in configuration of `strategy` and `backtest`, users can look up the documents: `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
Here is the configuration details of different `Record Template` such as ``SignalRecord`` and ``PortAnaRecord``:
.. code-block:: YAML
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
For more information about the ``Record`` module in ``Qlib``, user can refer to the related document: `Record <../component/recorder.html#record-template>`_.

View File

@@ -124,7 +124,7 @@ html_theme_options = {
"logo_only": True,
"collapse_navigation": False,
"display_version": False,
"navigation_depth": 3,
"navigation_depth": 4,
}
# Add any paths that contain custom static files (such as style sheets) here,

View File

@@ -35,12 +35,13 @@ Document Structure
:maxdepth: 3
:caption: COMPONENTS:
Estimator: Workflow Management <component/estimator.rst>
Workflow: Workflow Management <component/workflow.rst>
Data Layer: Data Framework&Usage <component/data.rst>
Interday Model: Model Training & Prediction <component/model.rst>
Interday Strategy: Portfolio Management <component/strategy.rst>
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Aanalysis: Evaluation & Results Analysis <component/report.rst>
Qlib Recorder: Experiment Management <component/recorder.rst>
Analysis: Evaluation & Results Analysis <component/report.rst>
.. toctree::
:maxdepth: 3
@@ -48,6 +49,7 @@ Document Structure
Building Formulaic Alphas <advanced/alpha.rst>
Online & Offline mode <advanced/server.rst>
.. toctree::
:maxdepth: 3
:caption: REFERENCE:

View File

@@ -21,27 +21,27 @@ Framework
At the module level, Qlib is a platform that consists of above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
====================== ==============================================================================
Name Description
====================== ==============================================================================
`Data layer` `DataServer` focuses on providing high-performance infrastructure for users to
manage and retrieve raw data. `DataEnhancement` will preprocess the data and
provide the best dataset to be fed into the models.
`Interday Model` `Interday model` focuses on producing prediction scores (aka. `alpha`). Models
are trained by `Model Creator` and managed by `Model Manager`. Users could
choose one or multiple models for prediction. Multiple models could be combined
with `Ensemble` module.
`Interday Strategy` `Portfolio Generator` will take prediction scores as input and output the
orders based on the current position to achieve the target portfolio.
======================== ==============================================================================
Name Description
======================== ==============================================================================
`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
`DataServer` provides high-performance infrastructure for users to manage
and retrieve raw data. `Trainer` provides flexible interface to control
the training process of models which enable algorithms controlling the
training process.
`Intraday Trading` `Order Executor` is responsible for executing orders output by
`Interday Strategy` and returning the executed results.
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
`Information Extractor` extracts data for models. `Forecast Model` focuses
on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
modules. With these signals `Portfolio Generator` will generate the target
portfolio and produce orders to be executed by `Order Executor`.
`Analysis` Users could get a detailed analysis report of forecasting signals and portfolios
in this part.
====================== ==============================================================================
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
system. `Analyser` module will provide users detailed analysis reports of
forecasting signals, portfolios and execution results
======================== ==============================================================================
- The modules with hand-drawn style are under development and will be released in the future.
- The modules with dashed borders are highly user-customizable and extendible.

View File

@@ -49,18 +49,19 @@ To kown more about `prepare data`, please refer to `Data Preparation <../compone
Auto Quant Research Workflow
====================================
``Qlib`` provides a tool named ``Estimator`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
``Qlib`` provides a tool named ``qrun`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
- Quant Research Workflow:
- Run ``Estimator`` with `estimator_config.yaml` as following.
- Run ``qrun`` with a config file of the LightGBM model `workflow_config_lightgbm.yaml` as following.
.. code-block::
cd examples # Avoid running program under the directory contains `qlib`
estimator -c estimator/estimator_config.yaml
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
- Estimator result
The result of ``Estimator`` is as follows, which is also the result of ``Intraday Trading``. Please refer to `Intraday Trading <../component/backtest.html>`_. for more details about the result.
- Workflow result
The result of ``qrun`` is as follows, which is also the typical result of ``Forecast model(alpha)``. Please refer to `Intraday Trading <../component/backtest.html>`_. for more details about the result.
.. code-block:: python
@@ -77,17 +78,17 @@ Auto Quant Research Workflow
max_drawdown -0.075024
To know more about `Estimator`, please refer to `Estimator: Workflow Management <../component/estimator.html>`_.
To know more about `workflow` and `qrun`, please refer to `Workflow: Workflow Management <../component/workflow.html>`_.
- Graphical Reports Analysis:
- Run ``examples/estimator/analyze_from_estimator.ipynb`` with jupyter notebook
Users can have portfolio analysis or prediction score (model prediction) analysis by run ``examples/estimator/analyze_from_estimator.ipynb``.
- Run ``examples/workflow_by_code.ipynb`` with jupyter notebook
Users can have portfolio analysis or prediction score (model prediction) analysis by run ``examples/workflow_by_code.ipynb``.
- Graphical Reports
Users can get graphical reports about the analysis, please refer to `Aanalysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
Users can get graphical reports about the analysis, please refer to `Analysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
Custom Model Integration
===============================================
``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. If users are interested in the custom model, please refer to `Custom Model Integration <../start/integration.html>`_.
``Qlib`` provides a batch of models (such as ``lightGBM`` and ``MLP`` models) as examples of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. If users are interested in the custom model, please refer to `Custom Model Integration <../start/integration.html>`_.

View File

@@ -23,16 +23,13 @@ Filter
.. automodule:: qlib.data.filter
:members:
Feature
--------------------
Class
~~~~~~~~~~~~~~~~~~~~
--------------------
.. automodule:: qlib.data.base
:members:
Operator
~~~~~~~~~~~~~~~~~~~~
--------------------
.. automodule:: qlib.data.ops
:members:
@@ -56,19 +53,36 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:
Dataset
---------------
Dataset Class
~~~~~~~~~~~~~~~~~~~~
.. automodule:: qlib.data.dataset.__init__
:members:
Data Loader
~~~~~~~~~~~~~~~~~~~~
.. automodule:: qlib.data.dataset.loader
:members:
Data Handler
~~~~~~~~~~~~~~~~~~~~
.. automodule:: qlib.data.dataset.handler
:members:
Processor
~~~~~~~~~~~~~~~~~~~~
.. automodule:: qlib.data.dataset.processor
:members:
Contrib
====================
Data Handler
---------------
.. automodule:: qlib.contrib.estimator.handler
:members:
Model
--------------------
.. automodule:: qlib.contrib.model.base
.. automodule:: qlib.model.base
:members:
Strategy
@@ -116,3 +130,26 @@ Report
:members:
Workflow
====================
Experiment Manager
--------------------
.. autoclass:: qlib.workflow.expm.ExpManager
:members:
Experiment
--------------------
.. autoclass:: qlib.workflow.exp.Experiment
:members:
Recorder
--------------------
.. autoclass:: qlib.workflow.recorder.Recorder
:members:
Record Template
--------------------
.. automodule:: qlib.workflow.record_temp
:members:

View File

@@ -1,4 +1,5 @@
.. _getdata:
=============================
Data Retrieval
=============================

View File

@@ -1,4 +1,5 @@
.. _initialization:
====================
Qlib Initialization
====================
@@ -11,14 +12,16 @@ Initialization
Please follow the steps below to initialize ``Qlib``.
- Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>` for more information about customized dataset.
Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized dataset.
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
- Initialize Qlib before calling other APIs: run following code in python.
Initialize Qlib before calling other APIs: run following code in python.
.. code-block:: Python
@@ -58,3 +61,16 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
.. note::
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
- `exp_manager`
Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
.. code-block:: Python
# For example, if you want to set your tracking_uri to a <specific folder>, you can initialize qlib below
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager= {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": "python_execution_path/mlruns",
"default_exp_name": "Experiment",
}
})

View File

@@ -1,4 +1,5 @@
.. _installation:
====================
Installation
====================

View File

@@ -5,17 +5,17 @@ Custom Model Integration
Introduction
===================
``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``.
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc.. These models are examples of ``Interday Model``. In addition to the default models ``Qlib`` provide, users can integrate their own custom models into ``Qlib``.
Users can integrate their own custom models according to the following steps.
- Define a custom model class, which should be a subclass of the `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_.
- Define a custom model class, which should be a subclass of the `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_.
- Write a configuration file that describes the path and parameters of the custom model.
- Test the custom model.
Custom Model Class
===========================
The Custom models need to inherit `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ and override the methods in it.
The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ and override the methods in it.
- Override the `__init__` method
- ``Qlib`` passes the initialized parameters to the \_\_init\_\_ method.
@@ -32,79 +32,77 @@ The Custom models need to inherit `qlib.contrib.model.base.Model <../reference/a
- Override the `fit` method
- ``Qlib`` calls the fit method to train the model
- The parameters must include training feature `x_train`, training label `y_train`, test feature `x_valid`, test label `y_valid` at least.
- The parameters could include some optional parameters with default values, such as train weight `w_train`, test weight `w_valid` and `num_boost_round = 1000`.
- The parameters must include training feature `dataset`.
- The parameters could include some optional parameters with default values, such as `num_boost_round = 1000` for `GBDT`.
- Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.
.. code-block:: Python
def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame,
w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs):
def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
# prepare dataset for lgb training and evaluation
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
raise ValueError('LightGBM doesn\'t support multi-label training')
raise ValueError("LightGBM doesn't support multi-label training")
w_train_weight = None if w_train is None else w_train.values
w_valid_weight = None if w_valid is None else w_valid.values
dtrain = lgb.Dataset(x_train.values, label=y_train)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
self._model = lgb.train(
self._params,
dtrain,
# fit the model
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=['train', 'valid'],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
- Override the `predict` method
- The parameters include the test features.
- The parameters must include training feature `dataset`, which will be userd to get the test dataset.
- Return the `prediction score`.
- Please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_ for the parameter types of the fit method.
- Code Example: In the following example, users need to use dnn to predict the label(such as `preds`) of test data `x_test` and return it.
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
- Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
.. code-block:: Python
def predict(self, x_test:pd.DataFrame, **kwargs)-> numpy.ndarray:
if self._model is None:
raise ValueError('model is not fitted yet!')
return self._model.predict(x_test.values)
def predict(self, dataset: DatasetH, **kwargs)-> pandas.Series:
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
- Override the `save` method & `load` method
- The `save` method parameter includes the a `filename` that represents an absolute path, user need to save model into the path.
- The `load` method parameter includes the a `buffer` read from the `filename` passed in the `save` method, users need to load model from the `buffer`.
- Code Example:
- Override the `finetune` method
- The parameters must include training feature `dataset`.
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
.. code-block:: Python
def save(self, filename):
if self._model is None:
raise ValueError('model is not fitted yet!')
self._model.save_model(filename)
def load(self, buffer):
self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')})
.. Without tuner, this part will not be used
.. - Override the `score` method(This step is optional)
.. - The parameters include the test features and test labels.
.. - Return the evaluation score of the model. It's recommended to adopt the loss between labels and `prediction score`.
.. - Code Example: In the following example, users need to calculate the weighted loss with test data `x_test`, test label `y_test` and the weight `w_test`.
.. .. code-block:: Python
..
.. def score(self, x_test:pd.Dataframe, y_test:pd.Dataframe, w_test:pd.DataFrame = None) -> float:
.. # Remove rows from x, y and w, which contain Nan in any columns in y_test.
.. x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
.. preds = self.predict(x_test)
.. w_test_weight = None if w_test is None else w_test.values
.. scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score
.. return scorer(y_test.values, preds, sample_weight=w_test_weight)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
)
Configuration File
=======================
The configuration file is described in detail in the `estimator <../component/estimator.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file.
The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file. The configuration describes which models to use and how we can initialize it.
- Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
@@ -124,23 +122,23 @@ The configuration file is described in detail in the `estimator <../component/es
num_leaves: 210
num_threads: 20
Users could find configuration file of the baseline of the ``Model`` in ``qlib/examples/estimator/estimator_config.yaml`` and ``qlib/examples/estimator/estimator_config_dnn.yaml``
Users could find configuration file of the baselines of the ``Model`` in ``examples/benchmarks``. All the configurations of different models are listed under the corresponding model folder.
Model Testing
=====================
Assuming that the configuration file is ``examples/estimator/estimator_config.yaml``, users can run the following command to test the custom model:
Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following command to test the custom model:
.. code-block:: bash
cd examples # Avoid running program under the directory contains `qlib`
estimator -c estimator/estimator_config.yaml
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
.. note:: ``estimator`` is a built-in command of ``Qlib``.
.. note:: ``qrun`` is a built-in command of ``Qlib``.
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/train_backtest_analyze.ipynb``.
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/workflow_by_code.ipynb``.
Reference
=====================
To know more about ``Interday Model``, please refer to `Interday Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
To know more about ``Interday Model``, please refer to `Interday Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.model.base>`_.

View File

@@ -0,0 +1,8 @@
# ALSTM
- ALSTM contains a temporal attentive aggregation layer based on normal LSTM.
- Paper: A dual-stage attention-based recurrent neural network for time series prediction.
[https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)

View File

@@ -0,0 +1,4 @@
numpy==1.17.4
pandas==1.1.2
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,83 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: ALSTM
module_path: qlib.contrib.model.pytorch_alstm
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
loss: mse
seed: 0
GPU: 0
rnn_type: GRU
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
# CatBoost
* Code: [https://github.com/catboost/catboost](https://github.com/catboost/catboost)
* Paper: CatBoost: unbiased boosting with categorical features. [https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf](https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf).

View File

@@ -0,0 +1,3 @@
pandas==1.1.2
numpy==1.17.4
catboost==0.24.3

View File

@@ -0,0 +1,64 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: CatBoostModel
module_path: qlib.contrib.model.catboost_model
kwargs:
loss: RMSE
learning_rate: 0.0421
subsample: 0.8789
max_depth: 6
num_leaves: 100
thread_count: 20
grow_policy: Lossguide
bootstrap_type: Poisson
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,5 @@
# GATs
* Graph Attention Networks(GATs) leverage masked self-attentional layers on graph-structured data. The nodes in stacked layers have different weights and they are able to attend over their
neighborhoods features, without requiring any kind of costly matrix operation (such as inversion) or depending on knowing the graph structure upfront.
* This code used in Qlib is implemented with PyTorch by ourselves.
* Paper: Graph Attention Networks https://arxiv.org/pdf/1710.10903.pdf

View File

@@ -0,0 +1,4 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,77 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GATs
module_path: qlib.contrib.model.pytorch_gats
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.7
n_epochs: 200
lr: 1e-4
early_stop: 20
metric: loss
loss: mse
base_model: LSTM
seed: 0
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

Binary file not shown.

View File

@@ -0,0 +1,4 @@
numpy==1.17.4
pandas==1.1.2
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,82 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: GRU
module_path: qlib.contrib.model.pytorch_gru
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
loss: mse
seed: 0
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

Binary file not shown.

View File

@@ -0,0 +1,4 @@
numpy==1.17.4
pandas==1.1.2
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,82 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LSTM
module_path: qlib.contrib.model.pytorch_lstm
kwargs:
d_feat: 6
hidden_size: 64
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
early_stop: 20
batch_size: 800
metric: loss
loss: mse
seed: 0
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,4 @@
# LightGBM
* Code: [https://github.com/microsoft/LightGBM](https://github.com/microsoft/LightGBM)
* Paper: LightGBM: A Highly Efficient Gradient Boosting
Decision Tree. [https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf).

View File

@@ -0,0 +1,3 @@
pandas==1.1.2
numpy==1.17.4
lightgbm==3.1.0

View File

@@ -0,0 +1,65 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
numpy>=1.17.4
pandas>=1.0.1
scikit-learn>=0.23.1

View File

@@ -0,0 +1,71 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LinearModel
module_path: qlib.contrib.model.linear
kwargs:
estimator: ols
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: True
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,4 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,93 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors: [
{
"class" : "DropCol",
"kwargs":{"col_list": ["VWAP0"]}
},
{
"class" : "CSZFillna",
"kwargs":{"fields_group": "feature"}
}
]
learn_processors: [
{
"class" : "DropCol",
"kwargs":{"col_list": ["VWAP0"]}
},
{
"class" : "DropnaProcessor",
"kwargs":{"fields_group": "feature"}
},
"DropnaLabel",
{
"class": "CSZScoreNorm",
"kwargs": {"fields_group": "label"}
}
]
process_type: "independent"
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: DNNModelPytorch
module_path: qlib.contrib.model.pytorch_nn
kwargs:
loss: mse
input_dim: 157
output_dim: 1
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: adam
max_steps: 8000
batch_size: 4096
GPU: 0
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
# State-Frequency-Memory
- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.)

View File

@@ -0,0 +1,4 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
torch==1.7.0

View File

@@ -0,0 +1,85 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
infer_processors:
- class: RobustZScoreNorm
kwargs:
fields_group: feature
clip_outlier: true
- class: Fillna
kwargs:
fields_group: feature
learn_processors:
- class: DropnaLabel
- class: CSRankNorm
kwargs:
fields_group: label
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: SFM
module_path: qlib.contrib.model.pytorch_sfm
kwargs:
d_feat: 6
hidden_size: 64
output_dim: 32
freq_dim: 25
dropout_W: 0.5
dropout_U: 0.5
n_epochs: 20
lr: 1e-3
batch_size: 1600
early_stop: 20
eval_steps: 5
loss: mse
optimizer: adam
GPU: 1
seed: 710
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,14 @@
# Temporal Fusion Transformers Benchmark
## Source
**Reference**: Lim, Bryan, et al. "Temporal fusion transformers for interpretable multi-horizon time series forecasting." arXiv preprint arXiv:1912.09363 (2019).
**GitHub**: https://github.com/google-research/google-research/tree/master/tft
## Run the Workflow
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
### Notes
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
3. The model must run in GPU, or an error will be raised.
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.

View File

@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,223 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Default data formatting functions for experiments.
For new datasets, inherit form GenericDataFormatter and implement
all abstract functions.
These dataset-specific methods:
1) Define the column and input types for tabular dataframes used by model
2) Perform the necessary input feature engineering & normalisation steps
3) Reverts the normalisation for predictions
4) Are responsible for train, validation and test splits
"""
import abc
import enum
# Type defintions
class DataTypes(enum.IntEnum):
"""Defines numerical types of each column."""
REAL_VALUED = 0
CATEGORICAL = 1
DATE = 2
class InputTypes(enum.IntEnum):
"""Defines input types of each column."""
TARGET = 0
OBSERVED_INPUT = 1
KNOWN_INPUT = 2
STATIC_INPUT = 3
ID = 4 # Single column used as an entity identifier
TIME = 5 # Single column exclusively used as a time index
class GenericDataFormatter(abc.ABC):
"""Abstract base class for all data formatters.
User can implement the abstract methods below to perform dataset-specific
manipulations.
"""
@abc.abstractmethod
def set_scalers(self, df):
"""Calibrates scalers using the data supplied."""
raise NotImplementedError()
@abc.abstractmethod
def transform_inputs(self, df):
"""Performs feature transformation."""
raise NotImplementedError()
@abc.abstractmethod
def format_predictions(self, df):
"""Reverts any normalisation to give predictions in original scale."""
raise NotImplementedError()
@abc.abstractmethod
def split_data(self, df):
"""Performs the default train, validation and test splits."""
raise NotImplementedError()
@property
@abc.abstractmethod
def _column_definition(self):
"""Defines order, input type and data type of each column."""
raise NotImplementedError()
@abc.abstractmethod
def get_fixed_params(self):
"""Defines the fixed parameters used by the model for training.
Requires the following keys:
'total_time_steps': Defines the total number of time steps used by TFT
'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
'num_epochs': Maximum number of epochs for training
'early_stopping_patience': Early stopping param for keras
'multiprocessing_workers': # of cpus for data processing
Returns:
A dictionary of fixed parameters, e.g.:
fixed_params = {
'total_time_steps': 252 + 5,
'num_encoder_steps': 252,
'num_epochs': 100,
'early_stopping_patience': 5,
'multiprocessing_workers': 5,
}
"""
raise NotImplementedError
# Shared functions across data-formatters
@property
def num_classes_per_cat_input(self):
"""Returns number of categories per relevant input.
This is seqeuently required for keras embedding layers.
"""
return self._num_classes_per_cat_input
def get_num_samples_for_calibration(self):
"""Gets the default number of training and validation samples.
Use to sub-sample the data for network calibration and a value of -1 uses
all available samples.
Returns:
Tuple of (training samples, validation samples)
"""
return -1, -1
def get_column_definition(self):
""""Returns formatted column definition in order expected by the TFT."""
column_definition = self._column_definition
# Sanity checks first.
# Ensure only one ID and time column exist
def _check_single_column(input_type):
length = len([tup for tup in column_definition if tup[2] == input_type])
if length != 1:
raise ValueError("Illegal number of inputs ({}) of type {}".format(length, input_type))
_check_single_column(InputTypes.ID)
_check_single_column(InputTypes.TIME)
identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]
time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
real_inputs = [
tup
for tup in column_definition
if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME}
]
categorical_inputs = [
tup
for tup in column_definition
if tup[1] == DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME}
]
return identifier + time + real_inputs + categorical_inputs
def _get_input_columns(self):
"""Returns names of all input columns."""
return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
def _get_tft_input_indices(self):
"""Returns the relevant indexes and input sizes required by TFT."""
# Functions
def _extract_tuples_from_data_type(data_type, defn):
return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}]
def _get_locations(input_types, defn):
return [i for i, tup in enumerate(defn) if tup[2] in input_types]
# Start extraction
column_definition = [
tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}
]
categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition)
real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition)
locations = {
"input_size": len(self._get_input_columns()),
"output_size": len(_get_locations({InputTypes.TARGET}, column_definition)),
"category_counts": self.num_classes_per_cat_input,
"input_obs_loc": _get_locations({InputTypes.TARGET}, column_definition),
"static_input_loc": _get_locations({InputTypes.STATIC_INPUT}, column_definition),
"known_regular_inputs": _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs),
"known_categorical_inputs": _get_locations(
{InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, categorical_inputs
),
}
return locations
def get_experiment_params(self):
"""Returns fixed model parameters for experiments."""
required_keys = [
"total_time_steps",
"num_encoder_steps",
"num_epochs",
"early_stopping_patience",
"multiprocessing_workers",
]
fixed_params = self.get_fixed_params()
for k in required_keys:
if k not in fixed_params:
raise ValueError("Field {}".format(k) + " missing from fixed parameter definitions!")
fixed_params["column_definition"] = self.get_column_definition()
fixed_params.update(self._get_tft_input_indices())
return fixed_params

View File

@@ -0,0 +1,219 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Custom formatting functions for Alpha158 dataset.
Defines dataset specific column definitions and data transformations.
"""
import data_formatters.base
import libs.utils as utils
import sklearn.preprocessing
GenericDataFormatter = data_formatters.base.GenericDataFormatter
DataTypes = data_formatters.base.DataTypes
InputTypes = data_formatters.base.InputTypes
class Alpha158Formatter(GenericDataFormatter):
"""Defines and formats data for the Alpha158 dataset.
Attributes:
column_definition: Defines input and data type of column used in the
experiment.
identifiers: Entity identifiers used in experiments.
"""
_column_definition = [
("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
("date", DataTypes.DATE, InputTypes.TIME),
("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
# Selected 10 features
("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
]
def __init__(self):
"""Initialises formatter."""
self.identifiers = None
self._real_scalers = None
self._cat_scalers = None
self._target_scaler = None
self._num_classes_per_cat_input = None
def split_data(self, df, valid_boundary=2016, test_boundary=2018):
"""Splits data frame into training-validation-test data frames.
This also calibrates scaling object, and transforms data for each split.
Args:
df: Source data frame to split.
valid_boundary: Starting year for validation data
test_boundary: Starting year for test data
Returns:
Tuple of transformed (train, valid, test) data.
"""
print("Formatting train-valid-test splits.")
index = df["year"]
train = df.loc[index < valid_boundary]
valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
test = df.loc[index >= test_boundary]
self.set_scalers(train)
return (self.transform_inputs(data) for data in [train, valid, test])
def set_scalers(self, df):
"""Calibrates scalers using the data supplied.
Args:
df: Data to use to calibrate scalers.
"""
print("Setting scalers with training data...")
column_definitions = self.get_column_definition()
id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
# Extract identifiers in case required
self.identifiers = list(df[id_column].unique())
# Format real scalers
real_inputs = utils.extract_cols_from_data_type(
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
data = df[real_inputs].values
self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
df[[target_column]].values
) # used for predictions
# Format categorical scalers
categorical_inputs = utils.extract_cols_from_data_type(
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
categorical_scalers = {}
num_classes = []
for col in categorical_inputs:
# Set all to str so that we don't have mixed integer/string columns
srs = df[col].apply(str)
categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
num_classes.append(srs.nunique())
# Set categorical scaler outputs
self._cat_scalers = categorical_scalers
self._num_classes_per_cat_input = num_classes
def transform_inputs(self, df):
"""Performs feature transformations.
This includes both feature engineering, preprocessing and normalisation.
Args:
df: Data frame to transform.
Returns:
Transformed data frame.
"""
output = df.copy()
if self._real_scalers is None and self._cat_scalers is None:
raise ValueError("Scalers have not been set!")
column_definitions = self.get_column_definition()
real_inputs = utils.extract_cols_from_data_type(
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
categorical_inputs = utils.extract_cols_from_data_type(
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
)
# Format real inputs
output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
# Format categorical inputs
for col in categorical_inputs:
string_df = df[col].apply(str)
output[col] = self._cat_scalers[col].transform(string_df)
return output
def format_predictions(self, predictions):
"""Reverts any normalisation to give predictions in original scale.
Args:
predictions: Dataframe of model predictions.
Returns:
Data frame of unnormalised predictions.
"""
output = predictions.copy()
column_names = predictions.columns
for col in column_names:
if col not in {"forecast_time", "identifier"}:
output[col] = self._target_scaler.inverse_transform(predictions[col])
return output
# Default params
def get_fixed_params(self):
"""Returns fixed model parameters for experiments."""
fixed_params = {
"total_time_steps": 6 + 6,
"num_encoder_steps": 6,
"num_epochs": 100,
"early_stopping_patience": 10,
"multiprocessing_workers": 5,
}
return fixed_params
def get_default_model_params(self):
"""Returns default optimised model parameters."""
model_params = {
"dropout_rate": 0.4,
"hidden_layer_size": 16,
"learning_rate": 0.0001,
"minibatch_size": 128,
"max_gradient_norm": 0.0135,
"num_heads": 1,
"stack_size": 1,
}
return model_params

View File

@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Default configs for TFT experiments.
Contains the default output paths for data, serialised models and predictions
for the main experiments used in the publication.
"""
import os
import data_formatters.qlib_Alpha158
class ExperimentConfig(object):
"""Defines experiment configs and paths to outputs.
Attributes:
root_folder: Root folder to contain all experimental outputs.
experiment: Name of experiment to run.
data_folder: Folder to store data for experiment.
model_folder: Folder to store serialised models.
results_folder: Folder to store results.
data_csv_path: Path to primary data csv file used in experiment.
hyperparam_iterations: Default number of random search iterations for
experiment.
"""
default_experiments = ["Alpha158"]
def __init__(self, experiment="volatility", root_folder=None):
"""Creates configs based on default experiment chosen.
Args:
experiment: Name of experiment.
root_folder: Root folder to save all outputs of training.
"""
if experiment not in self.default_experiments:
raise ValueError("Unrecognised experiment={}".format(experiment))
# Defines all relevant paths
if root_folder is None:
root_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "outputs")
print("Using root folder {}".format(root_folder))
self.root_folder = root_folder
self.experiment = experiment
self.data_folder = os.path.join(root_folder, "data", experiment)
self.model_folder = os.path.join(root_folder, "saved_models", experiment)
self.results_folder = os.path.join(root_folder, "results", experiment)
# Creates folders if they don't exist
for relevant_directory in [self.root_folder, self.data_folder, self.model_folder, self.results_folder]:
if not os.path.exists(relevant_directory):
os.makedirs(relevant_directory)
@property
def data_csv_path(self):
csv_map = {
"Alpha158": "Alpha158.csv",
}
return os.path.join(self.data_folder, csv_map[self.experiment])
@property
def hyperparam_iterations(self):
return 240 if self.experiment == "volatility" else 60
def make_data_formatter(self):
"""Gets a data formatter object for experiment.
Returns:
Default DataFormatter per experiment.
"""
data_formatter_class = {
"Alpha158": data_formatters.qlib_Alpha158.Alpha158Formatter,
}
return data_formatter_class[self.experiment]()

View File

@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,430 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Classes used for hyperparameter optimisation.
Two main classes exist:
1) HyperparamOptManager used for optimisation on a single machine/GPU.
2) DistributedHyperparamOptManager for multiple GPUs on different machines.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import shutil
import libs.utils as utils
import numpy as np
import pandas as pd
Deque = collections.deque
class HyperparamOptManager:
"""Manages hyperparameter optimisation using random search for a single GPU.
Attributes:
param_ranges: Discrete hyperparameter range for random search.
results: Dataframe of validation results.
fixed_params: Fixed model parameters per experiment.
saved_params: Dataframe of parameters trained.
best_score: Minimum validation loss observed thus far.
optimal_name: Key to best configuration.
hyperparam_folder: Where to save optimisation outputs.
"""
def __init__(self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True):
"""Instantiates model.
Args:
param_ranges: Discrete hyperparameter range for random search.
fixed_params: Fixed model parameters per experiment.
model_folder: Folder to store optimisation artifacts.
override_w_fixed_params: Whether to override serialsed fixed model
parameters with new supplied values.
"""
self.param_ranges = param_ranges
self._max_tries = 1000
self.results = pd.DataFrame()
self.fixed_params = fixed_params
self.saved_params = pd.DataFrame()
self.best_score = np.Inf
self.optimal_name = ""
# Setup
# Create folder for saving if its not there
self.hyperparam_folder = model_folder
utils.create_folder_if_not_exist(self.hyperparam_folder)
self._override_w_fixed_params = override_w_fixed_params
def load_results(self):
"""Loads results from previous hyperparameter optimisation.
Returns:
A boolean indicating if previous results can be loaded.
"""
print("Loading results from", self.hyperparam_folder)
results_file = os.path.join(self.hyperparam_folder, "results.csv")
params_file = os.path.join(self.hyperparam_folder, "params.csv")
if os.path.exists(results_file) and os.path.exists(params_file):
self.results = pd.read_csv(results_file, index_col=0)
self.saved_params = pd.read_csv(params_file, index_col=0)
if not self.results.empty:
self.results.at["loss"] = self.results.loc["loss"].apply(float)
self.best_score = self.results.loc["loss"].min()
is_optimal = self.results.loc["loss"] == self.best_score
self.optimal_name = self.results.T[is_optimal].index[0]
return True
return False
def _get_params_from_name(self, name):
"""Returns previously saved parameters given a key."""
params = self.saved_params
selected_params = dict(params[name])
if self._override_w_fixed_params:
for k in self.fixed_params:
selected_params[k] = self.fixed_params[k]
return selected_params
def get_best_params(self):
"""Returns the optimal hyperparameters thus far."""
optimal_name = self.optimal_name
return self._get_params_from_name(optimal_name)
def clear(self):
"""Clears all previous results and saved parameters."""
shutil.rmtree(self.hyperparam_folder)
os.makedirs(self.hyperparam_folder)
self.results = pd.DataFrame()
self.saved_params = pd.DataFrame()
def _check_params(self, params):
"""Checks that parameter map is properly defined."""
valid_fields = list(self.param_ranges.keys()) + list(self.fixed_params.keys())
invalid_fields = [k for k in params if k not in valid_fields]
missing_fields = [k for k in valid_fields if k not in params]
if invalid_fields:
raise ValueError("Invalid Fields Found {} - Valid ones are {}".format(invalid_fields, valid_fields))
if missing_fields:
raise ValueError("Missing Fields Found {} - Valid ones are {}".format(missing_fields, valid_fields))
def _get_name(self, params):
"""Returns a unique key for the supplied set of params."""
self._check_params(params)
fields = list(params.keys())
fields.sort()
return "_".join([str(params[k]) for k in fields])
def get_next_parameters(self, ranges_to_skip=None):
"""Returns the next set of parameters to optimise.
Args:
ranges_to_skip: Explicitly defines a set of keys to skip.
"""
if ranges_to_skip is None:
ranges_to_skip = set(self.results.index)
if not isinstance(self.param_ranges, dict):
raise ValueError("Only works for random search!")
param_range_keys = list(self.param_ranges.keys())
param_range_keys.sort()
def _get_next():
"""Returns next hyperparameter set per try."""
parameters = {k: np.random.choice(self.param_ranges[k]) for k in param_range_keys}
# Adds fixed params
for k in self.fixed_params:
parameters[k] = self.fixed_params[k]
return parameters
for _ in range(self._max_tries):
parameters = _get_next()
name = self._get_name(parameters)
if name not in ranges_to_skip:
return parameters
raise ValueError("Exceeded max number of hyperparameter searches!!")
def update_score(self, parameters, loss, model, info=""):
"""Updates the results from last optimisation run.
Args:
parameters: Hyperparameters used in optimisation.
loss: Validation loss obtained.
model: Model to serialised if required.
info: Any ancillary information to tag on to results.
Returns:
Boolean flag indicating if the model is the best seen so far.
"""
if np.isnan(loss):
loss = np.Inf
if not os.path.isdir(self.hyperparam_folder):
os.makedirs(self.hyperparam_folder)
name = self._get_name(parameters)
is_optimal = self.results.empty or loss < self.best_score
# save the first model
if is_optimal:
# Try saving first, before updating info
if model is not None:
print("Optimal model found, updating")
model.save(self.hyperparam_folder)
self.best_score = loss
self.optimal_name = name
self.results[name] = pd.Series({"loss": loss, "info": info})
self.saved_params[name] = pd.Series(parameters)
self.results.to_csv(os.path.join(self.hyperparam_folder, "results.csv"))
self.saved_params.to_csv(os.path.join(self.hyperparam_folder, "params.csv"))
return is_optimal
class DistributedHyperparamOptManager(HyperparamOptManager):
"""Manages distributed hyperparameter optimisation across many gpus."""
def __init__(
self,
param_ranges,
fixed_params,
root_model_folder,
worker_number,
search_iterations=1000,
num_iterations_per_worker=5,
clear_serialised_params=False,
):
"""Instantiates optimisation manager.
This hyperparameter optimisation pre-generates #search_iterations
hyperparameter combinations and serialises them
at the start. At runtime, each worker goes through their own set of
parameter ranges. The pregeneration
allows for multiple workers to run in parallel on different machines without
resulting in parameter overlaps.
Args:
param_ranges: Discrete hyperparameter range for random search.
fixed_params: Fixed model parameters per experiment.
root_model_folder: Folder to store optimisation artifacts.
worker_number: Worker index definining which set of hyperparameters to
test.
search_iterations: Maximum numer of random search iterations.
num_iterations_per_worker: How many iterations are handled per worker.
clear_serialised_params: Whether to regenerate hyperparameter
combinations.
"""
max_workers = int(np.ceil(search_iterations / num_iterations_per_worker))
# Sanity checks
if worker_number > max_workers:
raise ValueError(
"Worker number ({}) cannot be larger than the total number of workers!".format(max_workers)
)
if worker_number > search_iterations:
raise ValueError(
"Worker number ({}) cannot be larger than the max search iterations ({})!".format(
worker_number, search_iterations
)
)
print("*** Creating hyperparameter manager for worker {} ***".format(worker_number))
hyperparam_folder = os.path.join(root_model_folder, str(worker_number))
super().__init__(param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True)
serialised_ranges_folder = os.path.join(root_model_folder, "hyperparams")
if clear_serialised_params:
print("Regenerating hyperparameter list")
if os.path.exists(serialised_ranges_folder):
shutil.rmtree(serialised_ranges_folder)
utils.create_folder_if_not_exist(serialised_ranges_folder)
self.serialised_ranges_path = os.path.join(serialised_ranges_folder, "ranges_{}.csv".format(search_iterations))
self.hyperparam_folder = hyperparam_folder # override
self.worker_num = worker_number
self.total_search_iterations = search_iterations
self.num_iterations_per_worker = num_iterations_per_worker
self.global_hyperparam_df = self.load_serialised_hyperparam_df()
self.worker_search_queue = self._get_worker_search_queue()
@property
def optimisation_completed(self):
return False if self.worker_search_queue else True
def get_next_parameters(self):
"""Returns next dictionary of hyperparameters to optimise."""
param_name = self.worker_search_queue.pop()
params = self.global_hyperparam_df.loc[param_name, :].to_dict()
# Always override!
for k in self.fixed_params:
print("Overriding saved {}: {}".format(k, self.fixed_params[k]))
params[k] = self.fixed_params[k]
return params
def load_serialised_hyperparam_df(self):
"""Loads serialsed hyperparameter ranges from file.
Returns:
DataFrame containing hyperparameter combinations.
"""
print(
"Loading params for {} search iterations form {}".format(
self.total_search_iterations, self.serialised_ranges_path
)
)
if os.path.exists(self.serialised_ranges_folder):
df = pd.read_csv(self.serialised_ranges_path, index_col=0)
else:
print("Unable to load - regenerating serach ranges instead")
df = self.update_serialised_hyperparam_df()
return df
def update_serialised_hyperparam_df(self):
"""Regenerates hyperparameter combinations and saves to file.
Returns:
DataFrame containing hyperparameter combinations.
"""
search_df = self._generate_full_hyperparam_df()
print(
"Serialising params for {} search iterations to {}".format(
self.total_search_iterations, self.serialised_ranges_path
)
)
search_df.to_csv(self.serialised_ranges_path)
return search_df
def _generate_full_hyperparam_df(self):
"""Generates actual hyperparameter combinations.
Returns:
DataFrame containing hyperparameter combinations.
"""
np.random.seed(131) # for reproducibility of hyperparam list
name_list = []
param_list = []
for _ in range(self.total_search_iterations):
params = super().get_next_parameters(name_list)
name = self._get_name(params)
name_list.append(name)
param_list.append(params)
full_search_df = pd.DataFrame(param_list, index=name_list)
return full_search_df
def clear(self): # reset when cleared
"""Clears results for hyperparameter manager and resets."""
super().clear()
self.worker_search_queue = self._get_worker_search_queue()
def load_results(self):
"""Load results from file and queue parameter combinations to try.
Returns:
Boolean indicating if results were successfully loaded.
"""
success = super().load_results()
if success:
self.worker_search_queue = self._get_worker_search_queue()
return success
def _get_worker_search_queue(self):
"""Generates the queue of param combinations for current worker.
Returns:
Queue of hyperparameter combinations outstanding.
"""
global_df = self.assign_worker_numbers(self.global_hyperparam_df)
worker_df = global_df[global_df["worker"] == self.worker_num]
left_overs = [s for s in worker_df.index if s not in self.results.columns]
return Deque(left_overs)
def assign_worker_numbers(self, df):
"""Updates parameter combinations with the index of the worker used.
Args:
df: DataFrame of parameter combinations.
Returns:
Updated DataFrame with worker number.
"""
output = df.copy()
n = self.total_search_iterations
batch_size = self.num_iterations_per_worker
max_worker_num = int(np.ceil(n / batch_size))
worker_idx = np.concatenate([np.tile(i + 1, self.num_iterations_per_worker) for i in range(max_worker_num)])
output["worker"] = worker_idx[: len(output)]
return output

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,224 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Generic helper functions used across codebase."""
import os
import pathlib
import numpy as np
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# Generic.
def get_single_col_by_input_type(input_type, column_definition):
"""Returns name of single column.
Args:
input_type: Input type of column to extract
column_definition: Column definition list for experiment
"""
l = [tup[0] for tup in column_definition if tup[2] == input_type]
if len(l) != 1:
raise ValueError("Invalid number of columns for {}".format(input_type))
return l[0]
def extract_cols_from_data_type(data_type, column_definition, excluded_input_types):
"""Extracts the names of columns that correspond to a define data_type.
Args:
data_type: DataType of columns to extract.
column_definition: Column definition to use.
excluded_input_types: Set of input types to exclude
Returns:
List of names for columns with data type specified.
"""
return [tup[0] for tup in column_definition if tup[1] == data_type and tup[2] not in excluded_input_types]
# Loss functions.
def tensorflow_quantile_loss(y, y_pred, quantile):
"""Computes quantile loss for tensorflow.
Standard quantile loss as defined in the "Training Procedure" section of
the main TFT paper
Args:
y: Targets
y_pred: Predictions
quantile: Quantile to use for loss calculations (between 0 & 1)
Returns:
Tensor for quantile loss.
"""
# Checks quantile
if quantile < 0 or quantile > 1:
raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))
prediction_underflow = y - y_pred
q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * tf.maximum(
-prediction_underflow, 0.0
)
return tf.reduce_sum(q_loss, axis=-1)
def numpy_normalised_quantile_loss(y, y_pred, quantile):
"""Computes normalised quantile loss for numpy arrays.
Uses the q-Risk metric as defined in the "Training Procedure" section of the
main TFT paper.
Args:
y: Targets
y_pred: Predictions
quantile: Quantile to use for loss calculations (between 0 & 1)
Returns:
Float for normalised quantile loss.
"""
prediction_underflow = y - y_pred
weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
-prediction_underflow, 0.0
)
quantile_loss = weighted_errors.mean()
normaliser = y.abs().mean()
return 2 * quantile_loss / normaliser
# OS related functions.
def create_folder_if_not_exist(directory):
"""Creates folder if it doesn't exist.
Args:
directory: Folder path to create.
"""
# Also creates directories recursively
pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
# Tensorflow related functions.
def get_default_tensorflow_config(tf_device="gpu", gpu_id=0):
"""Creates tensorflow config for graphs to run on CPU or GPU.
Specifies whether to run graph on gpu or cpu and which GPU ID to use for multi
GPU machines.
Args:
tf_device: 'cpu' or 'gpu'
gpu_id: GPU ID to use if relevant
Returns:
Tensorflow config.
"""
if tf_device == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # for training on cpu
tf_config = tf.ConfigProto(log_device_placement=False, device_count={"GPU": 0})
else:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
print("Selecting GPU ID={}".format(gpu_id))
tf_config = tf.ConfigProto(log_device_placement=False)
tf_config.gpu_options.allow_growth = True
return tf_config
def save(tf_session, model_folder, cp_name, scope=None):
"""Saves Tensorflow graph to checkpoint.
Saves all trainiable variables under a given variable scope to checkpoint.
Args:
tf_session: Session containing graph
model_folder: Folder to save models
cp_name: Name of Tensorflow checkpoint
scope: Variable scope containing variables to save
"""
# Save model
if scope is None:
saver = tf.train.Saver()
else:
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
save_path = saver.save(tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name)))
print("Model saved to: {0}".format(save_path))
def load(tf_session, model_folder, cp_name, scope=None, verbose=False):
"""Loads Tensorflow graph from checkpoint.
Args:
tf_session: Session to load graph into
model_folder: Folder containing serialised model
cp_name: Name of Tensorflow checkpoint
scope: Variable scope to use.
verbose: Whether to print additional debugging information.
"""
# Load model proper
load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
print("Loading model from {0}".format(load_path))
print_weights_in_checkpoint(model_folder, cp_name)
initial_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
# Saver
if scope is None:
saver = tf.train.Saver()
else:
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
# Load
saver.restore(tf_session, load_path)
all_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
if verbose:
print("Restored {0}".format(",".join(initial_vars.difference(all_vars))))
print("Existing {0}".format(",".join(all_vars.difference(initial_vars))))
print("All {0}".format(",".join(all_vars)))
print("Done.")
def print_weights_in_checkpoint(model_folder, cp_name):
"""Prints all weights in Tensorflow checkpoint.
Args:
model_folder: Folder containing checkpoint
cp_name: Name of checkpoint
Returns:
"""
load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
print_tensors_in_checkpoint_file(file_name=load_path, tensor_name="", all_tensors=True, all_tensor_names=True)

View File

@@ -0,0 +1,3 @@
tensorflow-gpu==1.15.0
numpy == 1.19.4
pandas==1.1.0

View File

@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
import data_formatters.base
import expt_settings.configs
import libs.hyperparam_opt
import libs.tft_model
import libs.utils as utils
import os
import datetime as dte
from qlib.model.base import ModelFT
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
# To register new datasets, please add them here.
ALLOW_DATASET = ["Alpha158"]
DATASET_SETTING = {
"Alpha158": {
"feature_col": ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "ROC60", "RESI10"],
"label_col": ["LABEL0"],
},
}
# To register new datasets, please add their configurations here.
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
def fill_test_na(test_df):
test_df_res = test_df.copy()
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
test_df_res.loc[:, feature_cols] = test_feature_fna
return test_df_res
def process_qlib_data(df, dataset, fillna=False):
"""Prepare data to fit the TFT model.
Args:
df: Original DataFrame.
fillna: Whether to fill the data with the mean values.
Returns:
Transformed DataFrame.
"""
# Several features selected manually
feature_col = DATASET_SETTING[dataset]["feature_col"]
label_col = DATASET_SETTING[dataset]["label_col"]
temp_df = df.loc[:, feature_col + label_col]
if fillna:
temp_df = fill_test_na(temp_df)
temp_df = temp_df.swaplevel()
temp_df = temp_df.sort_index()
temp_df = temp_df.reset_index(level=0)
dates = pd.to_datetime(temp_df.index)
temp_df["date"] = dates
temp_df["day_of_week"] = dates.dayofweek
temp_df["month"] = dates.month
temp_df["year"] = dates.year
temp_df["const"] = 1.0
return temp_df
def process_predicted(df, col_name):
"""Transform the TFT predicted data into Qlib format.
Args:
df: Original DataFrame.
fillna: New column name.
Returns:
Transformed DataFrame.
"""
df_res = df.copy()
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
df_res = df_res[[col_name]]
return df_res
def format_score(forecast_df, col_name="pred", label_shift=5):
pred = process_predicted(forecast_df, col_name=col_name)
pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
pred = pred.dropna()[col_name]
return pred
def transform_df(df, col_name="LABEL0"):
df_res = df["feature"]
df_res[col_name] = df["label"]
return df_res
class TFTModel(ModelFT):
"""TFT Model"""
def __init__(self, **kwargs):
self.model = None
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
return transform_df(df_train), transform_df(df_valid)
def fit(
self,
dataset: DatasetH,
DATASET="Alpha158",
MODEL_FOLDER="qlib_alpha158_model",
LABEL_COL="LABEL0",
LABEL_SHIFT=5,
USE_GPU_ID=0,
**kwargs
):
if DATASET not in ALLOW_DATASET:
raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
dtrain, dvalid = self._prepare_data(dataset)
dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
ExperimentConfig = expt_settings.configs.ExperimentConfig
config = ExperimentConfig(DATASET)
self.data_formatter = config.make_data_formatter()
self.model_folder = MODEL_FOLDER
self.gpu_id = USE_GPU_ID
self.label_shift = LABEL_SHIFT
self.expt_name = DATASET
self.label_col = LABEL_COL
use_gpu = (True, self.gpu_id)
# ===========================Training Process===========================
ModelClass = libs.tft_model.TemporalFusionTransformer
if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
raise ValueError(
"Data formatters should inherit from"
+ "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
)
default_keras_session = tf.keras.backend.get_session()
if use_gpu[0]:
self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
else:
self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
self.data_formatter.set_scalers(train)
# Sets up default params
fixed_params = self.data_formatter.get_experiment_params()
params = self.data_formatter.get_default_model_params()
# Wendi: 合并调优的参数和非调优的参数
params = {**params, **fixed_params}
if not os.path.exists(self.model_folder):
os.makedirs(self.model_folder)
params["model_folder"] = self.model_folder
print("*** Begin training ***")
best_loss = np.Inf
tf.reset_default_graph()
self.tf_graph = tf.Graph()
with self.tf_graph.as_default():
self.sess = tf.Session(config=self.tf_config)
tf.keras.backend.set_session(self.sess)
self.model = ModelClass(params, use_cudnn=use_gpu[0])
self.sess.run(tf.global_variables_initializer())
self.model.fit(train_df=train, valid_df=valid)
print("*** Finished training ***")
saved_model_dir = self.model_folder + "/" + "saved_model"
if not os.path.exists(saved_model_dir):
os.makedirs(saved_model_dir)
self.model.save(saved_model_dir)
def extract_numerical_data(data):
"""Strips out forecast time and identifier columns."""
return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
# p50_loss = utils.numpy_normalised_quantile_loss(
# extract_numerical_data(targets), extract_numerical_data(p50_forecast),
# 0.5)
# p90_loss = utils.numpy_normalised_quantile_loss(
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
# 0.9)
tf.keras.backend.set_session(default_keras_session)
print("Training completed.".format(dte.datetime.now()))
# ===========================Training Process===========================
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
d_test = dataset.prepare("test", col_set=["feature", "label"])
d_test = transform_df(d_test)
d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
use_gpu = (True, self.gpu_id)
# ===========================Predicting Process===========================
default_keras_session = tf.keras.backend.get_session()
# Sets up default params
fixed_params = self.data_formatter.get_experiment_params()
params = self.data_formatter.get_default_model_params()
params = {**params, **fixed_params}
print("*** Begin predicting ***")
tf.reset_default_graph()
with self.tf_graph.as_default():
tf.keras.backend.set_session(self.sess)
output_map = self.model.predict(test, return_targets=True)
targets = self.data_formatter.format_predictions(output_map["targets"])
p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
tf.keras.backend.set_session(default_keras_session)
predict50 = format_score(p50_forecast, "pred", 1)
predict90 = format_score(p90_forecast, "pred", 1)
predict = (predict50 + predict90) / 2 # self.label_shift
# ===========================Predicting Process===========================
return predict
def finetune(self, dataset: DatasetH):
"""
finetune model
Parameters
----------
dataset : DatasetH
dataset for finetuning
"""
pass

View File

@@ -0,0 +1,52 @@
sys:
rel_path: .
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TFTModel
module_path: tft
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -0,0 +1,3 @@
# XGBoost
* Code: [https://github.com/dmlc/xgboost](https://github.com/dmlc/xgboost)
* Paper: XGBoost: A Scalable Tree Boosting System. [https://dl.acm.org/doi/pdf/10.1145/2939672.2939785](https://dl.acm.org/doi/pdf/10.1145/2939672.2939785).

View File

@@ -0,0 +1,3 @@
numpy==1.17.4
pandas==1.1.2
xgboost==1.2.1

View File

@@ -0,0 +1,63 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: XGBModel
module_path: qlib.contrib.model.xgboost
kwargs:
eval_metric: rmse
colsample_bytree: 0.8879
eta: 0.0421
max_depth: 8
n_estimators: 647
subsample: 0.8789
nthread: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -1,222 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import json\n",
"import yaml\n",
"import pickle\n",
"from pathlib import Path\n",
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.config import REG_CN\n",
"from qlib.utils import exists_qlib_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CUR_DIR = Path.cwd()\n",
"MARKET = \"csi300\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# use default data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data(target_dir=provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with CUR_DIR.joinpath('estimator_config.yaml').open() as fp:\n",
" estimator_name = yaml.load(fp, Loader=yaml.FullLoader)['experiment']['name']\n",
"with CUR_DIR.joinpath(estimator_name, 'exp_info.json').open() as fp:\n",
" latest_id = json.load(fp)['id']\n",
" \n",
"estimator_dir = CUR_DIR.joinpath(estimator_name, 'sacred', latest_id)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# read estimator result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_df = pd.read_pickle(estimator_dir.joinpath('pred.pkl'))\n",
"report_normal_df = pd.read_pickle(estimator_dir.joinpath('report_normal.pkl'))\n",
"report_normal_df.index.names = ['index']\n",
"\n",
"analysis_df = pd.read_pickle(estimator_dir.joinpath('analysis.pkl'))\n",
"positions = pickle.load(estimator_dir.joinpath('positions.pkl').open('rb'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze graphs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qlib.data import D\n",
"from qlib.contrib.report import analysis_model, analysis_position\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis position"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"stock_ret = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
"stock_ret.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### risk analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"label_df = D.features(D.instruments(MARKET), ['Ref($close, -2)/Ref($close, -1) - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
"label_df.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### score IC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
"analysis_position.score_ic_graph(pred_label)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### model performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_model.model_performance_graph(pred_label)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -1,53 +0,0 @@
experiment:
name: estimator_example
observer_type: file_storage
mode: train
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
args:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.0421
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
data:
class: Alpha158
args:
dropna_label: True
filter:
market: csi300
trainer:
class: StaticTrainer
args:
train_start_date: 2008-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-12-31
test_start_date: 2017-01-01
test_end_date: 2020-08-01
strategy:
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: SH000300
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"

View File

@@ -1,55 +0,0 @@
experiment:
name: estimator_example
observer_type: file_storage
mode: train
model:
module_path: qlib.contrib.model.pytorch_nn
class: DNNModelPytorch
args:
loss: mse
input_dim: 158
output_dim: 1
lr: 0.002
lr_decay: 0.96
lr_decay_steps: 100
optimizer: 'adam'
max_steps: 8000
batch_size: 4096
GPU: '0'
data:
class: Alpha158
args:
dropna_label: True
dropna_feature: True
filter:
market: csi300
trainer:
class: StaticTrainer
args:
train_start_date: 2007-01-01
train_end_date: 2014-12-31
validate_start_date: 2015-01-01
validate_end_date: 2016-12-31
test_start_date: 2017-01-01
test_end_date: 2020-08-01
strategy:
class: TopkDropoutStrategy
args:
topk: 50
n_drop: 5
backtest:
normal_backtest_args:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: SH000300
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
qlib_data:
# when testing, please modify the following parameters according to the specific environment
provider_uri: "~/.qlib/qlib_data/cn_data"
region: "cn"

284
examples/run_all_model.py Normal file
View File

@@ -0,0 +1,284 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import sys
import fire
import time
import venv
import glob
import shutil
import signal
import inspect
import tempfile
import traceback
import functools
import statistics
import subprocess
from pathlib import Path
from operator import xor
from pprint import pprint
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.cli import workflow
from qlib.utils import exists_qlib_data
# init qlib
provider_uri = "~/.qlib/qlib_data/cn_data"
exp_folder_name = "run_all_model_records"
exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
exp_manager = {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": "file:" + exp_path,
"default_exp_name": "Experiment",
},
}
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
if os.path.isdir(exp_path):
shutil.rmtree(exp_path)
# decorator to check the arguments
def only_allow_defined_args(function_to_decorate):
@functools.wraps(function_to_decorate)
def _return_wrapped(*args, **kwargs):
"""Internal wrapper function."""
argspec = inspect.getfullargspec(function_to_decorate)
valid_names = set(argspec.args + argspec.kwonlyargs)
if "self" in valid_names:
valid_names.remove("self")
for arg_name in kwargs:
if arg_name not in valid_names:
raise ValueError("Unknown argument seen '%s', expected: [%s]" % (arg_name, ", ".join(valid_names)))
return function_to_decorate(*args, **kwargs)
return _return_wrapped
# function to handle ctrl z and ctrl c
def handler(signum, frame):
os.system("kill -9 %d" % os.getpid())
signal.signal(signal.SIGTSTP, handler)
signal.signal(signal.SIGINT, handler)
# function to calculate the mean and std of a list in the results dictionary
def cal_mean_std(results) -> dict:
mean_std = dict()
for fn in results:
mean_std[fn] = dict()
for metric in results[fn]:
mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0]
std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0
mean_std[fn][metric] = [mean, std]
return mean_std
# function to create the environment ofr an anaconda environment
def create_env():
# create env
temp_dir = tempfile.mkdtemp()
env_path = Path(temp_dir).absolute()
sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
execute(f"conda create --prefix {env_path} python=3.7 -y")
python_path = env_path / "bin" / "python" # TODO: FIX ME!
sys.stderr.write("\n")
# get anaconda activate path
conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
return env_path, python_path, conda_activate
# function to execute the cmd
def execute(cmd):
with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
for line in p.stdout:
sys.stdout.write(line.split("\b")[0])
if "\b" in line:
sys.stdout.flush()
time.sleep(0.1)
sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
if p.returncode != 0:
return p.stderr
else:
return None
# function to get all the folders benchmark folder
def get_all_folders(models, exclude) -> dict:
folders = dict()
if isinstance(models, str):
model_list = models.split(",")
models = [m.lower().strip("[ ]") for m in model_list]
elif isinstance(models, list):
models = [m.lower() for m in models]
elif models is None:
models = [f.name.lower() for f in os.scandir("benchmarks")]
else:
raise ValueError("Input models type is not supported. Please provide str or list without space.")
for f in os.scandir("benchmarks"):
add = xor(bool(f.name.lower() in models), bool(exclude))
if add:
path = Path("benchmarks") / f.name
folders[f.name] = str(path.resolve())
return folders
# function to get all the files under the model folder
def get_all_files(folder_path) -> (str, str):
yaml_path = str(Path(f"{folder_path}") / "*.yaml")
req_path = str(Path(f"{folder_path}") / "*.txt")
return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
# function to retrieve all the results
def get_all_results(folders) -> dict:
results = dict()
for fn in folders:
exp = R.get_exp(experiment_name=fn, create=False)
recorders = exp.list_recorders()
result = dict()
result["annualized_return_with_cost"] = list()
result["information_ratio_with_cost"] = list()
result["max_drawdown_with_cost"] = list()
for recorder_id in recorders:
if recorders[recorder_id].status == "FINISHED":
recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
metrics = recorder.list_metrics()
result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
results[fn] = result
return results
# function to generate and save markdown table
def gen_and_save_md_table(metrics):
table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
table += "|---|---|---|---|\n"
for fn in metrics:
ar = metrics[fn]["annualized_return_with_cost"]
ir = metrics[fn]["information_ratio_with_cost"]
md = metrics[fn]["max_drawdown_with_cost"]
table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
pprint(table)
with open("table.md", "w") as f:
f.write(table)
return table
# function to run the all the models
@only_allow_defined_args
def run(times=1, models=None, exclude=False):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed.
Parameters:
-----------
times : int
determines how many times the model should be running.
models : str or list
determines the specific model or list of models to run or exclude.
exclude : boolean
determines whether the model being used is excluded or included.
Usage:
-------
Here are some use cases of the function in the bash:
.. code-block:: bash
# Case 1 - run all models multiple times
python run_all_model.py 3
# Case 2 - run specific models multiple times
python run_all_model.py 3 mlp
# Case 3 - run other models except those are given as arguments for multiple times
python run_all_model.py 3 [mlp,tft,lstm] True
# Case 4 - run specific models for one time
python run_all_model.py --models=[mlp,lightgbm]
# Case 5 - run other models except those are given as aruments for one time
python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
"""
# get all folders
folders = get_all_folders(models, exclude)
# init error messages:
errors = dict()
# run all the model for iterations
for fn in folders:
# create env by anaconda
env_path, python_path, conda_activate = create_env()
# get all files
sys.stderr.write("Retrieving files...\n")
yaml_path, req_path = get_all_files(folders[fn])
sys.stderr.write("\n")
# install requirements.txt
sys.stderr.write("Installing requirements.txt...\n")
execute(f"{python_path} -m pip install -r {req_path}")
sys.stderr.write("\n")
# setup gpu for tft
if fn == "TFT":
execute(
f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
)
sys.stderr.write("\n")
# install qlib
sys.stderr.write("Installing qlib...\n")
execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
if fn == "TFT":
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
) # TODO: FIX ME!
else:
execute(
f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
) # TODO: FIX ME!
sys.stderr.write("\n")
# run workflow_by_config for multiple times
for i in range(times):
sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
errs = execute(
f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
)
if errs is not None:
_errs = errors.get(fn, {})
_errs.update({i: errs})
errors[fn] = _errs
sys.stderr.write("\n")
# remove env
sys.stderr.write(f"Deleting the environment: {env_path}...\n")
shutil.rmtree(env_path)
# getting all results
sys.stderr.write(f"Retrieving results...\n")
results = get_all_results(folders)
# calculating the mean and std
sys.stderr.write(f"Calculating the mean and std of results...\n")
results = cal_mean_std(results)
# generating md table
sys.stderr.write(f"Generating markdown table...\n")
gen_and_save_md_table(results)
sys.stderr.write("\n")
# print erros
sys.stderr.write(f"Here are some of the errors of the models...\n")
pprint(errors)
sys.stderr.write("\n")
if __name__ == "__main__":
fire.Fire(run) # run all the model

View File

@@ -1,121 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.estimator.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri)
qlib.init(provider_uri=provider_uri, region=REG_CN)
MARKET = "CSI300"
BENCHMARK = "SH000300"
###################################
# train model
###################################
DATA_HANDLER_CONFIG = {
"dropna_label": True,
"start_date": "2008-01-01",
"end_date": "2020-08-01",
"market": MARKET,
}
TRAINER_CONFIG = {
"train_start_date": "2008-01-01",
"train_end_date": "2014-12-31",
"validate_start_date": "2015-01-01",
"validate_end_date": "2016-12-31",
"test_start_date": "2017-01-01",
"test_end_date": "2020-08-01",
}
# use default DataHandler
# custom DataHandler, refer to: TODO: DataHandler API url
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
**TRAINER_CONFIG
)
MODEL_CONFIG = {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
}
# use default model
# custom Model, refer to: TODO: Model API url
model = LGBModel(**MODEL_CONFIG)
model.fit(x_train, y_train, x_validate, y_validate)
_pred = model.predict(x_test)
_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
# backtest requires pred_score
pred_score = pd.DataFrame(index=_pred.index)
pred_score["score"] = _pred.iloc(axis=1)[0]
# save pred_score to file
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
pred_score.to_pickle(pred_score_path)
###################################
# backtest
###################################
STRATEGY_CONFIG = {
"topk": 50,
"n_drop": 5,
}
BACKTEST_CONFIG = {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": BENCHMARK,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
}
# use default strategy
# custom Strategy, refer to: TODO: Strategy API url
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
###################################
# analyze
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
###################################
analysis = dict()
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
print(analysis_df)

View File

@@ -1,338 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.config import REG_CN\n",
"from qlib.contrib.model.gbdt import LGBModel\n",
"from qlib.contrib.estimator.handler import Alpha158\n",
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
"from qlib.contrib.evaluate import (\n",
" backtest as normal_backtest,\n",
" risk_analysis,\n",
")\n",
"from qlib.utils import exists_qlib_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# use default data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data(target_dir=provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MARKET = \"csi300\"\n",
"BENCHMARK = \"SH000300\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# train model\n",
"###################################\n",
"DATA_HANDLER_CONFIG = {\n",
" \"dropna_label\": True,\n",
" \"start_date\": \"2008-01-01\",\n",
" \"end_date\": \"2020-08-01\",\n",
" \"market\": MARKET,\n",
"}\n",
"\n",
"TRAINER_CONFIG = {\n",
" \"train_start_date\": \"2008-01-01\",\n",
" \"train_end_date\": \"2014-12-31\",\n",
" \"validate_start_date\": \"2015-01-01\",\n",
" \"validate_end_date\": \"2016-12-31\",\n",
" \"test_start_date\": \"2017-01-01\",\n",
" \"test_end_date\": \"2020-08-01\",\n",
"}\n",
"\n",
"# use default DataHandler\n",
"# custom DataHandler, refer to: TODO: DataHandler api url\n",
"x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(**TRAINER_CONFIG)\n",
"\n",
"\n",
"MODEL_CONFIG = {\n",
" \"loss\": \"mse\",\n",
" \"colsample_bytree\": 0.8879,\n",
" \"learning_rate\": 0.0421,\n",
" \"subsample\": 0.8789,\n",
" \"lambda_l1\": 205.6999,\n",
" \"lambda_l2\": 580.9768,\n",
" \"max_depth\": 8,\n",
" \"num_leaves\": 210,\n",
" \"num_threads\": 20,\n",
"}\n",
"# use default model\n",
"# custom Model, refer to: TODO: Model api url\n",
"model = LGBModel(**MODEL_CONFIG)\n",
"model.fit(x_train, y_train, x_validate, y_validate)\n",
"_pred = model.predict(x_test)\n",
"_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)\n",
"\n",
"# backtest requires pred_score\n",
"pred_score = pd.DataFrame(index=_pred.index)\n",
"pred_score[\"score\"] = _pred.iloc(axis=1)[0]\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# backtest"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# backtest\n",
"###################################\n",
"STRATEGY_CONFIG = {\n",
" \"topk\": 50,\n",
" \"n_drop\": 5}\n",
"BACKTEST_CONFIG = {\n",
" \"verbose\": False,\n",
" \"limit_threshold\": 0.095,\n",
" \"account\": 100000000,\n",
" \"benchmark\": BENCHMARK,\n",
" \"deal_price\": \"close\",\n",
" \"open_cost\": 0.0005,\n",
" \"close_cost\": 0.0015,\n",
" \"min_cost\": 5,\n",
" \n",
"}\n",
"\n",
"# use default strategy\n",
"# custom Strategy, refer to: TODO: Strategy api url\n",
"strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)\n",
"report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# analyze\n",
"# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb\n",
"###################################\n",
"analysis = dict()\n",
"analysis[\"excess_return_without_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n",
"analysis[\"excess_return_with_cost\"] = risk_analysis(\n",
" report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"]\n",
")\n",
"analysis_df = pd.concat(analysis) # type: pd.DataFrame\n",
"print(analysis_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze graphs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"pred_df_dates = pred_score.index.get_level_values(level='datetime')\n",
"report_normal_df = report_normal\n",
"positions = positions_normal\n",
"pred_df = pred_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis position"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"stock_ret = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
"stock_ret.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### risk analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"label_df = D.features(D.instruments(MARKET), ['Ref($close, -2)/Ref($close, -1) - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
"label_df.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### score IC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
"analysis_position.score_ic_graph(pred_label)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### model performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_model.model_performance_graph(pred_label)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,380 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Microsoft Corporation.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import sys, site\n",
"from pathlib import Path\n",
"\n",
"\n",
"try:\n",
" import qlib\n",
"except ImportError:\n",
" # install qlib\n",
" ! pip install pyqlib\n",
" # reload\n",
" site.main()\n",
"\n",
"scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
"if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
" # download get_data.py script\n",
" scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
" scripts_dir.mkdir(parents=True, exist_ok=True)\n",
" import requests\n",
" with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
" with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
" fp.write(resp.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import qlib\n",
"import pandas as pd\n",
"from qlib.config import REG_CN\n",
"from qlib.contrib.model.gbdt import LGBModel\n",
"from qlib.contrib.data.handler import Alpha158\n",
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
"from qlib.contrib.evaluate import (\n",
" backtest as normal_backtest,\n",
" risk_analysis,\n",
")\n",
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
"from qlib.workflow import R\n",
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
"from qlib.utils import flatten_dict\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# use default data\n",
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
"if not exists_qlib_data(provider_uri):\n",
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(scripts_dir))\n",
" from get_data import GetData\n",
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"market = \"csi300\"\n",
"benchmark = \"SH000300\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# train model\n",
"###################################\n",
"data_handler_config = {\n",
" \"start_time\": \"2008-01-01\",\n",
" \"end_time\": \"2020-08-01\",\n",
" \"fit_start_time\": \"2008-01-01\",\n",
" \"fit_end_time\": \"2014-12-31\",\n",
" \"instruments\": market,\n",
"}\n",
"\n",
"task = {\n",
" \"model\": {\n",
" \"class\": \"LGBModel\",\n",
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
" \"kwargs\": {\n",
" \"loss\": \"mse\",\n",
" \"colsample_bytree\": 0.8879,\n",
" \"learning_rate\": 0.0421,\n",
" \"subsample\": 0.8789,\n",
" \"lambda_l1\": 205.6999,\n",
" \"lambda_l2\": 580.9768,\n",
" \"max_depth\": 8,\n",
" \"num_leaves\": 210,\n",
" \"num_threads\": 20,\n",
" },\n",
" },\n",
" \"dataset\": {\n",
" \"class\": \"DatasetH\",\n",
" \"module_path\": \"qlib.data.dataset\",\n",
" \"kwargs\": {\n",
" \"handler\": {\n",
" \"class\": \"Alpha158\",\n",
" \"module_path\": \"qlib.contrib.data.handler\",\n",
" \"kwargs\": data_handler_config,\n",
" },\n",
" \"segments\": {\n",
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
" },\n",
" },\n",
" },\n",
"}\n",
"\n",
"# model initiaiton\n",
"model = init_instance_by_config(task[\"model\"])\n",
"dataset = init_instance_by_config(task[\"dataset\"])\n",
"\n",
"# start exp to train model\n",
"with R.start(experiment_name=\"train_model\"):\n",
" R.log_params(**flatten_dict(task))\n",
" model.fit(dataset)\n",
" R.save_objects(trained_model=model)\n",
" rid = R.get_recorder().id\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# prediction, backtest & analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"###################################\n",
"# prediction, backtest & analysis\n",
"###################################\n",
"port_analysis_config = {\n",
" \"strategy\": {\n",
" \"class\": \"TopkDropoutStrategy\",\n",
" \"module_path\": \"qlib.contrib.strategy.strategy\",\n",
" \"kwargs\": {\n",
" \"topk\": 50,\n",
" \"n_drop\": 5,\n",
" },\n",
" },\n",
" \"backtest\": {\n",
" \"verbose\": False,\n",
" \"limit_threshold\": 0.095,\n",
" \"account\": 100000000,\n",
" \"benchmark\": benchmark,\n",
" \"deal_price\": \"close\",\n",
" \"open_cost\": 0.0005,\n",
" \"close_cost\": 0.0015,\n",
" \"min_cost\": 5,\n",
" },\n",
"}\n",
"\n",
"\n",
"# backtest and analysis\n",
"with R.start(experiment_name=\"backtest_analysis\"):\n",
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
" model = recorder.load_object(\"trained_model\")\n",
"\n",
" # prediction\n",
" recorder = R.get_recorder()\n",
" ba_rid = recorder.id\n",
" sr = SignalRecord(model, dataset, recorder)\n",
" sr.generate()\n",
"\n",
" # backtest & analysis\n",
" par = PortAnaRecord(recorder, port_analysis_config)\n",
" par.generate()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# analyze graphs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
"from qlib.data import D\n",
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
"pred_df = recorder.load_object(\"pred.pkl\")\n",
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
"positions = recorder.load_object(\"portfolio_analysis/positions_normal.pkl\")\n",
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis.pkl\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis position"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### report"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.report_graph(report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### risk analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## analysis model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
"label_df.columns = ['label']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### score IC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
"analysis_position.score_ic_graph(pred_label)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### model performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_model.model_performance_graph(pred_label)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -0,0 +1,120 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
port_analysis_config = {
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.strategy",
"kwargs": {
"topk": 50,
"n_drop": 5,
},
},
"backtest": {
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
},
}
# model initiaiton
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
# start exp
with R.start(experiment_name="workflow"):
R.log_params(**flatten_dict(task))
model.fit(dataset)
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()

View File

@@ -2,19 +2,20 @@
# Licensed under the MIT License.
__version__ = "0.5.1.dev0"
__version__ = "0.6.0.alpha"
import os
import copy
import logging
import re
import subprocess
import platform
import sys
import copy
import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .utils import can_use_cache
from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
from .workflow.utils import experiment_exit_handler
# init qlib
def init(default_conf="client", **kwargs):
@@ -22,6 +23,7 @@ def init(default_conf="client", **kwargs):
from .data.data import register_all_wrappers
from .log import get_module_logger, set_log_with_config
from .data.cache import H
from .workflow import R, QlibRecorder
C.reset()
H.clear()
@@ -34,17 +36,18 @@ def init(default_conf="client", **kwargs):
if _logging_config:
set_log_with_config(_logging_config)
# FIXME: this logger ignored the level in config
LOG = get_module_logger("Initialization", level=logging.INFO)
LOG.info(f"default_conf: {default_conf}.")
C.set_mode(default_conf)
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
for k, v in kwargs.items():
C[k] = v
if k not in C:
LOG.warning("Unrecognized config %s" % k)
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
C.resolve_path()
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
@@ -61,12 +64,10 @@ def init(default_conf="client", **kwargs):
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
LOG.error(
"Invalid provider uri: {}, please check if a valid provider uri has been set. This path does not exist.".format(
C["provider_uri"]
)
f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
)
else:
LOG.warning("auto_path is False, please make sure {} is mounted".format(C["mount_path"]))
LOG.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
elif C.get_uri_type() == QlibConfig.NFS_URI:
_mount_nfs_uri(C)
else:
@@ -80,6 +81,13 @@ def init(default_conf="client", **kwargs):
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
# set up QlibRecorder
exp_manager = init_instance_by_config(C["exp_manager"])
qr = QlibRecorder(exp_manager)
R.register(qr)
# clean up experiment when python program ends
experiment_exit_handler()
def _mount_nfs_uri(C):
from .log import get_module_logger
@@ -94,9 +102,7 @@ def _mount_nfs_uri(C):
if not C["auto_mount"]:
if not os.path.exists(C["mount_path"]):
raise FileNotFoundError(
"Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format(
C["mount_path"], mount_command
)
f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
)
else:
# Judging system type
@@ -153,9 +159,7 @@ def _mount_nfs_uri(C):
os.makedirs(C["mount_path"], exist_ok=True)
except Exception:
raise OSError(
"Failed to create directory {}, please create {} manually!".format(
C["mount_path"], C["mount_path"]
)
f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
)
# check nfs-common
@@ -167,20 +171,18 @@ def _mount_nfs_uri(C):
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
"mount {} on {} error! Needs SUDO! Please mount manually: {}".format(
C["provider_uri"], C["mount_path"], mount_command
)
f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
)
elif command_status == 32512:
# LOG.error("Command error")
raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"]))
raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
elif command_status == 0:
LOG.info("Mount finished")
else:
LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path))
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
def init_from_yaml_conf(conf_path):
def init_from_yaml_conf(conf_path, **kwargs):
"""init_from_yaml_conf
:param conf_path: A path to the qlib config in yml format
@@ -188,5 +190,6 @@ def init_from_yaml_conf(conf_path):
with open(conf_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)

View File

@@ -14,6 +14,8 @@ Two modes are supported
import copy
from pathlib import Path
import re
import os
import multiprocessing
class Config:
@@ -62,6 +64,8 @@ class Config:
REG_CN = "cn"
REG_US = "us"
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
_default_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
@@ -78,7 +82,7 @@ _default_config = {
"calendar_cache": None,
# for simple dataset cache
"local_cache_path": None,
"kernels": 16,
"kernels": NUM_USABLE_CPU,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
"default_disk_cache": 1, # 0:skip/1:use
@@ -124,6 +128,15 @@ _default_config = {
},
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
},
# Defatult config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
"kwargs": {
"uri": "file:" + str(Path(os.getcwd()).resolve() / "mlruns"),
"default_exp_name": "Experiment",
},
},
}
MODE_CONF = {
@@ -141,10 +154,11 @@ MODE_CONF = {
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
"kernels": 64,
"kernels": NUM_USABLE_CPU,
# cache
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
"mount_path": None,
},
"client": {
# data provider config
@@ -162,7 +176,7 @@ MODE_CONF = {
"dataset_cache": "DiskDatasetCache",
"calendar_cache": None,
# client config
"kernels": 16,
"kernels": NUM_USABLE_CPU,
"mount_path": None,
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
# The nfs should be auto-mounted by qlib on other
@@ -212,7 +226,9 @@ class QlibConfig(Config):
def get_uri_type(self):
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
is_nfs_or_win = re.match("^[^/]+:.+", self["provider_uri"]) is not None # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
is_nfs_or_win = (
re.match("^[^/]+:.+", self["provider_uri"]) is not None
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI

View File

@@ -10,6 +10,7 @@ from ...data import D
from .account import Account
from ...config import C
from ...log import get_module_logger
from ...data.dataset.utils import get_level_index
LOG = get_module_logger("backtest")
@@ -18,7 +19,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
"""Parameters
----------
pred : pandas.DataFrame
predict should has <instrument, datetime> index and one `score` column
predict should has <datetime, instrument> index and one `score` column
Qlib want to support multi-singal strategy in the future. So pd.Series is not used.
strategy : Strategy()
strategy part for backtest
trade_exchange : Exchange()
@@ -43,6 +45,12 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
`benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000905 CSI500
"""
# Convert format if the input format is not expected
if get_level_index(pred, level="datetime") == 1:
pred = pred.swaplevel().sort_index()
if isinstance(pred, pd.Series):
pred = pred.to_frame("score")
trade_account = Account(init_cash=account)
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
@@ -57,6 +65,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
get_date_by_shift(predict_dates[-1], shift=shift),
disk_cache=1,
)
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
@@ -71,7 +81,7 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
# 1. Load the score_series at pred_date
try:
score = pred.loc(axis=0)[:, pred_date] # (stock_id, trade_date) multi_index, score in pdate
score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
score_series = score.reset_index(level="datetime", drop=True)[
"score"
] # pd.Series(index:stock_id, data: score)

View File

@@ -166,7 +166,7 @@ class Position:
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
cash = pd.Series()
cash = pd.Series(dtype=np.float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]

View File

@@ -0,0 +1,429 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_cls_kwargs
from ...data.dataset import processor as processor_module
from ...log import TimeInspector
from inspect import getfullargspec
import copy
def check_transform_proc(proc_l, fit_start_time, fit_end_time):
new_l = []
for p in proc_l:
if not isinstance(p, Processor):
klass, pkwargs = get_cls_kwargs(p, processor_module)
args = getfullargspec(klass).args
if "fit_start_time" in args and "fit_end_time" in args:
assert (
fit_start_time is not None and fit_end_time is not None
), "Make sure `fit_start_time` and `fit_end_time` are not None."
pkwargs.update(
{
"fit_start_time": fit_start_time,
"fit_end_time": fit_end_time,
}
)
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)
return new_l
class ALPHA360_Denoise(DataHandlerLP):
def __init__(self, instruments="csi500", start_time=None, end_time=None, fit_start_time=None, fit_end_time=None):
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
"feature": self.get_feature_config(),
"label": self.get_label_config(),
},
},
}
learn_processors = [
{"class": "DropnaLabel", "kwargs": {"fields_group": "label"}},
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
]
infer_processors = [
{"class": "ProcessInf", "kwargs": {}},
{"class": "TanhProcess", "kwargs": {}},
{"class": "Fillna", "kwargs": {}},
]
super().__init__(
instruments,
start_time,
end_time,
data_loader=data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
)
def get_label_config(self):
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
def get_feature_config(self):
fields = []
names = []
for i in range(59, 0, -1):
fields += ["Ref($close, %d)/$close" % (i)]
names += ["CLOSE%d" % (i)]
fields += ["$close/$close"]
names += ["CLOSE0"]
for i in range(59, 0, -1):
fields += ["Ref($open, %d)/$close" % (i)]
names += ["OPEN%d" % (i)]
fields += ["$open/$close"]
names += ["OPEN0"]
for i in range(59, 0, -1):
fields += ["Ref($high, %d)/$close" % (i)]
names += ["HIGH%d" % (i)]
fields += ["$high/$close"]
names += ["HIGH0"]
for i in range(59, 0, -1):
fields += ["Ref($low, %d)/$close" % (i)]
names += ["LOW%d" % (i)]
fields += ["$low/$close"]
names += ["LOW0"]
for i in range(59, 0, -1):
fields += ["Ref($vwap, %d)/$close" % (i)]
names += ["VWAP%d" % (i)]
fields += ["$vwap/$close"]
names += ["VWAP0"]
for i in range(59, 0, -1):
fields += ["Ref($volume, %d)/$volume" % (i)]
names += ["VOLUME%d" % (i)]
fields += ["$volume/$volume"]
names += ["VOLUME0"]
return fields, names
_DEFAULT_LEARN_PROCESSORS = [
{"class": "DropnaLabel"},
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
]
_DEFAULT_INFER_PROCESSORS = [
{"class": "ProcessInf", "kwargs": {}},
{"class": "ZScoreNorm", "kwargs": {}},
{"class": "Fillna", "kwargs": {}},
]
class ALPHA360(DataHandlerLP):
def __init__(
self,
instruments="csi500",
start_time=None,
end_time=None,
infer_processors=_DEFAULT_INFER_PROCESSORS,
learn_processors=_DEFAULT_LEARN_PROCESSORS,
fit_start_time=None,
fit_end_time=None,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
"feature": self.get_feature_config(),
"label": kwargs.get("label", self.get_label_config()),
},
},
}
super().__init__(
instruments,
start_time,
end_time,
data_loader=data_loader,
learn_processors=learn_processors,
infer_processors=infer_processors,
)
def get_label_config(self):
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
def get_feature_config(self):
fields = []
names = []
for i in range(59, 0, -1):
fields += ["Ref($close, %d)/$close" % (i)]
names += ["CLOSE%d" % (i)]
fields += ["$close/$close"]
names += ["CLOSE0"]
for i in range(59, 0, -1):
fields += ["Ref($open, %d)/$close" % (i)]
names += ["OPEN%d" % (i)]
fields += ["$open/$close"]
names += ["OPEN0"]
for i in range(59, 0, -1):
fields += ["Ref($high, %d)/$close" % (i)]
names += ["HIGH%d" % (i)]
fields += ["$high/$close"]
names += ["HIGH0"]
for i in range(59, 0, -1):
fields += ["Ref($low, %d)/$close" % (i)]
names += ["LOW%d" % (i)]
fields += ["$low/$close"]
names += ["LOW0"]
for i in range(59, 0, -1):
fields += ["Ref($vwap, %d)/$close" % (i)]
names += ["VWAP%d" % (i)]
fields += ["$vwap/$close"]
names += ["VWAP0"]
for i in range(59, 0, -1):
fields += ["Ref($volume, %d)/$volume" % (i)]
names += ["VOLUME%d" % (i)]
fields += ["$volume/$volume"]
names += ["VOLUME0"]
return fields, names
class ALPHA360vwap(ALPHA360):
def get_label_config(self):
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
class Alpha158(DataHandlerLP):
def __init__(
self,
instruments="csi500",
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=_DEFAULT_LEARN_PROCESSORS,
fit_start_time=None,
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A,
**kwargs,
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())},
},
}
super().__init__(
instruments,
start_time,
end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
process_type=process_type,
)
def get_feature_config(self):
conf = {
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
},
"rolling": {},
}
return self.parse_config_to_fields(conf)
def get_label_config(self):
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
@staticmethod
def parse_config_to_fields(config):
"""create factors from config
config = {
'kbar': {}, # whether to use some hard-code kbar features
'price': { # whether to use raw price features
'windows': [0, 1, 2, 3, 4], # use price at n days ago
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
},
'volume': { # whether to use raw volume features
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
},
'rolling': { # whether to use rolling operator based features
'windows': [5, 10, 20, 30, 60], # rolling windows size
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
#if include is None we will use default operators
'exclude': ['RANK'], # rolling operator not to use
}
}
"""
fields = []
names = []
if "kbar" in config:
fields += [
"($close-$open)/$open",
"($high-$low)/$open",
"($close-$open)/($high-$low+1e-12)",
"($high-Greater($open, $close))/$open",
"($high-Greater($open, $close))/($high-$low+1e-12)",
"(Less($open, $close)-$low)/$open",
"(Less($open, $close)-$low)/($high-$low+1e-12)",
"(2*$close-$high-$low)/$open",
"(2*$close-$high-$low)/($high-$low+1e-12)",
]
names += [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
]
if "price" in config:
windows = config["price"].get("windows", range(5))
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
for field in feature:
field = field.lower()
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
names += [field.upper() + str(d) for d in windows]
if "volume" in config:
windows = config["volume"].get("windows", range(5))
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
names += ["VOLUME" + str(d) for d in windows]
if "rolling" in config:
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
include = config["rolling"].get("include", None)
exclude = config["rolling"].get("exclude", [])
# `exclude` in dataset config unnecessary filed
# `include` in dataset config necessary field
use = lambda x: x not in exclude and (include is None or x in include)
if use("ROC"):
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]
if use("MA"):
fields += ["Mean($close, %d)/$close" % d for d in windows]
names += ["MA%d" % d for d in windows]
if use("STD"):
fields += ["Std($close, %d)/$close" % d for d in windows]
names += ["STD%d" % d for d in windows]
if use("BETA"):
fields += ["Slope($close, %d)/$close" % d for d in windows]
names += ["BETA%d" % d for d in windows]
if use("RSQR"):
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
fields += ["Max($high, %d)/$close" % d for d in windows]
names += ["MAX%d" % d for d in windows]
if use("LOW"):
fields += ["Min($low, %d)/$close" % d for d in windows]
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
if use("RSV"):
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
names += ["RSV%d" % d for d in windows]
if use("IMAX"):
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
if use("IMXD"):
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
names += ["IMXD%d" % d for d in windows]
if use("CORR"):
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
names += ["CORR%d" % d for d in windows]
if use("CORD"):
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
names += ["CORD%d" % d for d in windows]
if use("CNTP"):
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
names += ["CNTP%d" % d for d in windows]
if use("CNTN"):
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
names += ["CNTN%d" % d for d in windows]
if use("CNTD"):
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
names += ["CNTD%d" % d for d in windows]
if use("SUMP"):
fields += [
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMP%d" % d for d in windows]
if use("SUMN"):
fields += [
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMN%d" % d for d in windows]
if use("SUMD"):
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VMA%d" % d for d in windows]
if use("VSTD"):
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]
return fields, names
class Alpha158vwap(Alpha158):
def get_label_config(self):
return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])

View File

@@ -0,0 +1,118 @@
import numpy as np
import pandas as pd
import copy
from ...log import TimeInspector
from ...utils.serial import Serializable
from ...data.dataset.processor import Processor, get_group_columns
class ConfigSectionProcessor(Processor):
"""
This processor is designed for Alpha158. And will be replaced by simple processors in the future
"""
def __init__(self, fields_group=None, **kwargs):
super().__init__()
# Options
self.fillna_feature = kwargs.get("fillna_feature", True)
self.fillna_label = kwargs.get("fillna_label", True)
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
self.fields_group = None
def __call__(self, df):
return self._transform(df)
def _transform(self, df):
def _label_norm(x):
x = x - x.mean() # copy
x /= x.std()
if self.clip_label_outlier:
x.clip(-3, 3, inplace=True)
if self.fillna_label:
x.fillna(0, inplace=True)
return x
def _feature_norm(x):
x = x - x.median() # copy
x /= x.abs().median() * 1.4826
if self.clip_feature_outlier:
x.clip(-3, 3, inplace=True)
if self.shrink_feature_outlier:
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
if self.fillna_feature:
x.fillna(0, inplace=True)
return x
TimeInspector.set_time_mark()
# Copy the focus part and change it to single level
selected_cols = get_group_columns(df, self.fields_group)
df_focus = df[selected_cols].copy()
if len(df_focus.columns.levels) > 1:
df_focus = df_focus.droplevel(level=0)
# Label
cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
# Features
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
_cols = [
"KMID",
"KSFT",
"OPEN",
"HIGH",
"LOW",
"CLOSE",
"VWAP",
"ROC",
"MA",
"BETA",
"RESI",
"QTLU",
"QTLD",
"RSV",
"SUMP",
"SUMN",
"SUMD",
"VSUMP",
"VSUMN",
"VSUMD",
]
pat = "|".join(["^" + x for x in _cols])
cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
df[selected_cols] = df_focus.values
TimeInspector.log_cost_time("Finished preprocessing data.")
return df

View File

@@ -1,176 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import yaml
import copy
import os
import json
import tempfile
from pathlib import Path
from ...config import REG_CN
class EstimatorConfigManager(object):
def __init__(self, config_path):
if not config_path:
raise ValueError("Config path is invalid.")
self.config_path = config_path
with open(config_path) as fp:
config = yaml.load(fp, Loader=yaml.FullLoader)
self.config = copy.deepcopy(config)
self.ex_config = ExperimentConfig(config.get("experiment", dict()), self)
self.data_config = DataConfig(config.get("data", dict()), self)
self.model_config = ModelConfig(config.get("model", dict()), self)
self.trainer_config = TrainerConfig(config.get("trainer", dict()), self)
self.strategy_config = StrategyConfig(config.get("strategy", dict()), self)
self.backtest_config = BacktestConfig(config.get("backtest", dict()), self)
self.qlib_data_config = QlibDataConfig(config.get("qlib_data", dict()), self)
# If the start_date and end_date are not given in data_config, they will be referred from the trainer_config.
handler_start_date = self.data_config.handler_parameters.get("start_date", None)
handler_end_date = self.data_config.handler_parameters.get("end_date", None)
if handler_start_date is None:
self.data_config.handler_parameters["start_date"] = self.trainer_config.parameters["train_start_date"]
if handler_end_date is None:
self.data_config.handler_parameters["end_date"] = self.trainer_config.parameters["test_end_date"]
class ExperimentConfig(object):
TRAIN_MODE = "train"
TEST_MODE = "test"
OBSERVER_FILE_STORAGE = "file_storage"
OBSERVER_MONGO = "mongo"
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for experiment
:param CONFIG_MANAGER: The estimator config manager
"""
self.name = config.get("name", "test_experiment")
# The dir of the result of all the experiments
self.global_dir = config.get("dir", os.path.dirname(CONFIG_MANAGER.config_path))
# The dir of the result of current experiment
self.ex_dir = os.path.join(self.global_dir, self.name)
if not os.path.exists(self.ex_dir):
os.makedirs(self.ex_dir)
self.tmp_run_dir = tempfile.mkdtemp(dir=self.ex_dir)
self.mode = config.get("mode", ExperimentConfig.TRAIN_MODE)
self.sacred_dir = os.path.join(self.ex_dir, "sacred")
self.observer_type = config.get("observer_type", ExperimentConfig.OBSERVER_FILE_STORAGE)
self.mongo_url = config.get("mongo_url", None)
self.db_name = config.get("db_name", None)
self.finetune = config.get("finetune", False)
# The path of the experiment id of the experiment
self.exp_info_path = config.get("exp_info_path", os.path.join(self.ex_dir, "exp_info.json"))
exp_info_dir = Path(self.exp_info_path).parent
exp_info_dir.mkdir(parents=True, exist_ok=True)
# Test mode config
loader_args = config.get("loader", dict())
if self.mode == ExperimentConfig.TEST_MODE or self.finetune:
loader_exp_info_path = loader_args.get("exp_info_path", None)
self.loader_model_index = loader_args.get("model_index", None)
if (loader_exp_info_path is not None) and (os.path.exists(loader_exp_info_path)):
with open(loader_exp_info_path) as fp:
loader_dict = json.load(fp)
for k, v in loader_dict.items():
setattr(self, "loader_{}".format(k), v)
# Check loader experiment id
assert hasattr(self, "loader_id"), "If mode is test or finetune is True, loader must contain id."
else:
self.loader_id = loader_args.get("id", None)
if self.loader_id is None:
raise ValueError("If mode is test or finetune is True, loader must contain id.")
self.loader_observer_type = loader_args.get("observer_type", self.observer_type)
self.loader_name = loader_args.get("name", self.name)
self.loader_dir = loader_args.get("dir", self.global_dir)
self.loader_mongo_url = loader_args.get("mongo_url", self.mongo_url)
self.loader_db_name = loader_args.get("db_name", self.db_name)
class DataConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for data
:param CONFIG_MANAGER: The estimator config manager
"""
self.handler_module_path = config.get("module_path", "qlib.contrib.estimator.handler")
self.handler_class = config.get("class", "ALPHA360")
self.handler_parameters = config.get("args", dict())
self.handler_filter = config.get("filter", dict())
# Update provider uri.
class ModelConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for model
:param CONFIG_MANAGER: The estimator config manager
"""
self.model_class = config.get("class", "Model")
self.model_module_path = config.get("module_path", "qlib.contrib.model")
self.save_dir = os.path.join(CONFIG_MANAGER.ex_config.tmp_run_dir, "model")
self.save_path = config.get("save_path", os.path.join(self.save_dir, "model.bin"))
self.parameters = config.get("args", dict())
# Make dir if need.
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
class TrainerConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for trainer
:param CONFIG_MANAGER: The estimator config manager
"""
self.trainer_class = config.get("class", "StaticTrainer")
self.trainer_module_path = config.get("module_path", "qlib.contrib.estimator.trainer")
self.parameters = config.get("args", dict())
class StrategyConfig(object):
def __init__(self, config, CONFIG_MANAGER):
"""__init__
:param config: The config dict for strategy
:param CONFIG_MANAGER: The estimator config manager
"""
self.strategy_class = config.get("class", "TopkDropoutStrategy")
self.strategy_module_path = config.get("module_path", "qlib.contrib.strategy.strategy")
self.parameters = config.get("args", dict())
class BacktestConfig(object):
def __init__(self, config, CONFIG_MANAGE):
"""__init__
:param config: The config dict for strategy
:param CONFIG_MANAGE: The estimator config manager
"""
self.normal_backtest_parameters = config.get("normal_backtest_args", dict())
self.long_short_backtest_parameters = config.get("long_short_backtest_args", dict())
class QlibDataConfig(object):
def __init__(self, config, CONFIG_MANAGE):
"""__init__
:param config: The config dict for qlib_client
:param CONFIG_MANAGE: The estimator config manager
"""
self.provider_uri = config.pop("provider_uri", "~/.qlib/qlib_data/cn_data")
self.auto_mount = config.pop("auto_mount", False)
self.mount_path = config.pop("mount_path", "~/.qlib/qlib_data/cn_data")
self.region = config.pop("region", REG_CN)
self.args = config

View File

@@ -1,328 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
import pandas as pd
import os
import copy
import json
import yaml
import pickle
import qlib
from ..evaluate import risk_analysis
from ..evaluate import backtest as normal_backtest
from ..evaluate import long_short_backtest
from .config import ExperimentConfig
from .fetcher import create_fetcher_with_config
from ...log import get_module_logger, TimeInspector
from ...utils import get_module_by_module_path, compare_dict_value
class Estimator(object):
def __init__(self, config_manager, sacred_ex):
# Set logger.
self.logger = get_module_logger("Estimator")
# 1. Set config manager.
self.config_manager = config_manager
# 2. Set configs.
self.ex_config = config_manager.ex_config
self.data_config = config_manager.data_config
self.model_config = config_manager.model_config
self.trainer_config = config_manager.trainer_config
self.strategy_config = config_manager.strategy_config
self.backtest_config = config_manager.backtest_config
# If experiment.mode is test or experiment.finetune is True, load the experimental results in the loader
if self.ex_config.mode == self.ex_config.TEST_MODE or self.ex_config.finetune:
self.compare_config_with_config_manger(self.config_manager)
# 3. Set sacred_experiment.
self.ex = sacred_ex
# 4. Init data handler.
self.data_handler = None
self._init_data_handler()
# 5. Init trainer.
self.trainer = None
self._init_trainer()
# 6. Init strategy.
self.strategy = None
self._init_strategy()
def _init_data_handler(self):
handler_module = get_module_by_module_path(self.data_config.handler_module_path)
# Set market
market = self.data_config.handler_filter.get("market", None)
if market is None:
if "market" in self.data_config.handler_parameters:
self.logger.warning(
"Warning: The market in data.args section is deprecated. "
"It only works when market is not set in data.filter section. "
"It will be overridden by market in the data.filter section."
)
market = self.data_config.handler_parameters["market"]
else:
market = "csi500"
self.data_config.handler_parameters["market"] = market
data_filter_list = []
handler_filters = self.data_config.handler_filter.get("filter_pipeline", list())
for h_filter in handler_filters:
filter_module_path = h_filter.get("module_path", "qlib.data.filter")
filter_class_name = h_filter.get("class", "")
filter_parameters = h_filter.get("args", {})
filter_module = get_module_by_module_path(filter_module_path)
filter_class = getattr(filter_module, filter_class_name)
data_filter = filter_class(**filter_parameters)
data_filter_list.append(data_filter)
self.data_config.handler_parameters["data_filter_list"] = data_filter_list
handler_class = getattr(handler_module, self.data_config.handler_class)
self.data_handler = handler_class(**self.data_config.handler_parameters)
def _init_trainer(self):
model_module = get_module_by_module_path(self.model_config.model_module_path)
trainer_module = get_module_by_module_path(self.trainer_config.trainer_module_path)
model_class = getattr(model_module, self.model_config.model_class)
trainer_class = getattr(trainer_module, self.trainer_config.trainer_class)
self.trainer = trainer_class(
model_class,
self.model_config.save_path,
self.model_config.parameters,
self.data_handler,
self.ex,
**self.trainer_config.parameters
)
def _init_strategy(self):
module = get_module_by_module_path(self.strategy_config.strategy_module_path)
strategy_class = getattr(module, self.strategy_config.strategy_class)
self.strategy = strategy_class(**self.strategy_config.parameters)
def run(self):
if self.ex_config.mode == ExperimentConfig.TRAIN_MODE:
self.trainer.train()
elif self.ex_config.mode == ExperimentConfig.TEST_MODE:
self.trainer.load()
else:
raise ValueError("unexpected mode: %s" % self.ex_config.mode)
analysis = self.backtest()
print(analysis)
self.logger.info(
"experiment id: {}, experiment name: {}".format(self.ex.experiment.current_run._id, self.ex_config.name)
)
# Remove temp dir
# shutil.rmtree(self.ex_config.tmp_run_dir)
def backtest(self):
TimeInspector.set_time_mark()
# 1. Get pred and prediction score of model(s).
pred = self.trainer.get_test_score()
try:
performance = self.trainer.get_test_performance()
except NotImplementedError:
performance = None
# 2. Normal Backtest.
report_normal, positions_normal = self._normal_backtest(pred)
# 3. Long-Short Backtest.
# Deprecated
# long_short_reports = self._long_short_backtest(pred)
# 4. Analyze
analysis_df = self._analyze(report_normal)
# 5. Save.
self._save_backtest_result(
pred,
analysis_df,
positions_normal,
report_normal,
# long_short_reports,
performance,
)
return analysis_df
def _normal_backtest(self, pred):
TimeInspector.set_time_mark()
if "account" not in self.backtest_config.normal_backtest_parameters:
if "account" in self.strategy_config.parameters:
self.logger.warning(
"Warning: The account in strategy section is deprecated. "
"It only works when account is not set in backtest section. "
"It will be overridden by account in the backtest section."
)
self.backtest_config.normal_backtest_parameters["account"] = self.strategy_config.parameters["account"]
report_normal, positions_normal = normal_backtest(
pred, strategy=self.strategy, **self.backtest_config.normal_backtest_parameters
)
TimeInspector.log_cost_time("Finished normal backtest.")
return report_normal, positions_normal
def _long_short_backtest(self, pred):
TimeInspector.set_time_mark()
long_short_reports = long_short_backtest(pred, **self.backtest_config.long_short_backtest_parameters)
TimeInspector.log_cost_time("Finished long-short backtest.")
return long_short_reports
@staticmethod
def _analyze(report_normal):
TimeInspector.set_time_mark()
analysis = dict()
# analysis["pred_long"] = risk_analysis(long_short_reports["long"])
# analysis["pred_short"] = risk_analysis(long_short_reports["short"])
# analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"])
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
TimeInspector.log_cost_time(
"Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean())
)
return analysis_df
def _save_backtest_result(self, pred, analysis, positions, report_normal, performance):
# 1. Result dir.
result_dir = os.path.join(self.config_manager.ex_config.tmp_run_dir, "result")
if not os.path.exists(result_dir):
os.makedirs(result_dir)
self.ex.add_info(
"task_config",
json.loads(json.dumps(self.config_manager.config, default=str)),
)
# 2. Pred.
TimeInspector.set_time_mark()
pred_pkl_path = os.path.join(result_dir, "pred.pkl")
pred.to_pickle(pred_pkl_path)
self.ex.add_artifact(pred_pkl_path)
TimeInspector.log_cost_time("Finished saving pred.pkl to: {}".format(pred_pkl_path))
# 3. Ana.
TimeInspector.set_time_mark()
analysis_pkl_path = os.path.join(result_dir, "analysis.pkl")
analysis.to_pickle(analysis_pkl_path)
self.ex.add_artifact(analysis_pkl_path)
TimeInspector.log_cost_time("Finished saving analysis.pkl to: {}".format(analysis_pkl_path))
# 4. Pos.
TimeInspector.set_time_mark()
positions_pkl_path = os.path.join(result_dir, "positions.pkl")
with open(positions_pkl_path, "wb") as fp:
pickle.dump(positions, fp)
self.ex.add_artifact(positions_pkl_path)
TimeInspector.log_cost_time("Finished saving positions.pkl to: {}".format(positions_pkl_path))
# 5. Report normal.
TimeInspector.set_time_mark()
report_normal_pkl_path = os.path.join(result_dir, "report_normal.pkl")
report_normal.to_pickle(report_normal_pkl_path)
self.ex.add_artifact(report_normal_pkl_path)
TimeInspector.log_cost_time("Finished saving report_normal.pkl to: {}".format(report_normal_pkl_path))
# 6. Report long short.
# Deprecated
# for k, name in zip(
# ["long", "short", "long_short"],
# ["report_long.pkl", "report_short.pkl", "report_long_short.pkl"],
# ):
# TimeInspector.set_time_mark()
# pkl_path = os.path.join(result_dir, name)
# long_short_reports[k].to_pickle(pkl_path)
# self.ex.add_artifact(pkl_path)
# TimeInspector.log_cost_time("Finished saving {} to: {}".format(name, pkl_path))
# 7. Origin test label.
TimeInspector.set_time_mark()
label_pkl_path = os.path.join(result_dir, "label.pkl")
self.data_handler.get_origin_test_label_with_date(
self.trainer_config.parameters["test_start_date"],
self.trainer_config.parameters["test_end_date"],
).to_pickle(label_pkl_path)
self.ex.add_artifact(label_pkl_path)
TimeInspector.log_cost_time("Finished saving label.pkl to: {}".format(label_pkl_path))
# 8. Experiment info, save the model(s) performance here.
TimeInspector.set_time_mark()
cur_ex_id = self.ex.experiment.current_run._id
exp_info = {
"id": cur_ex_id,
"name": self.ex_config.name,
"performance": performance,
"observer_type": self.ex_config.observer_type,
}
if self.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
exp_info.update(
{
"mongo_url": self.ex_config.mongo_url,
"db_name": self.ex_config.db_name,
}
)
else:
exp_info.update({"dir": self.ex_config.global_dir})
with open(self.ex_config.exp_info_path, "w") as fp:
json.dump(exp_info, fp, indent=4, sort_keys=True)
self.ex.add_artifact(self.ex_config.exp_info_path)
TimeInspector.log_cost_time("Finished saving ex_info to: {}".format(self.ex_config.exp_info_path))
@staticmethod
def compare_config_with_config_manger(config_manager):
"""Compare loader model args and current config with ConfigManage
:param config_manager: ConfigManager
:return:
"""
fetcher = create_fetcher_with_config(config_manager, load_form_loader=True)
loader_mode_config = fetcher.get_experiment(
exp_name=config_manager.ex_config.loader_name,
exp_id=config_manager.ex_config.loader_id,
fields=["task_config"],
)["task_config"]
with open(config_manager.config_path) as fp:
current_config = yaml.load(fp.read())
current_config = json.loads(json.dumps(current_config, default=str))
logger = get_module_logger("Estimator")
loader_mode_config = copy.deepcopy(loader_mode_config)
current_config = copy.deepcopy(current_config)
# Require test_mode_config.test_start_date <= current_config.test_start_date
loader_trainer_args = loader_mode_config.get("trainer", {}).get("args", {})
cur_trainer_args = current_config.get("trainer", {}).get("args", {})
loader_start_date = loader_trainer_args.pop("test_start_date")
cur_test_start_date = cur_trainer_args.pop("test_start_date")
assert (
loader_start_date <= cur_test_start_date
), "Require: loader_mode_config.test_start_date <= current_config.test_start_date"
# TODO: For the user's own extended `Trainer`, the support is not very good
if "RollingTrainer" == current_config.get("trainer", {}).get("class", None):
loader_period = loader_trainer_args.pop("rolling_period")
cur_period = cur_trainer_args.pop("rolling_period")
assert (
loader_period == cur_period
), "Require: loader_mode_config.rolling_period == current_config.rolling_period"
compare_section = ["trainer", "model", "data"]
for section in compare_section:
changes = compare_dict_value(loader_mode_config.get(section, {}), current_config.get(section, {}))
if changes:
logger.warning("Warning: Loader mode config and current config, `{}` are different:\n".format(section))

View File

@@ -1,290 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
import copy
import json
import yaml
import pickle
import gridfs
import pymongo
from pathlib import Path
from abc import abstractmethod
from .config import EstimatorConfigManager, ExperimentConfig
class Fetcher(object):
"""Sacred Experiments Fetcher"""
@abstractmethod
def _get_experiment(self, exp_name, exp_id):
"""Get experiment basic info with experiment and experiment id
:param exp_name: experiment name
:param exp_id: experiment id
:return: dict
Must contain keys: _id, experiment, info, stop_time.
Here is an example below for FileFetcher.
exp = {
'_id': exp_id, # experiment id
'path': path, # experiment result path
'experiment': {'name': exp_name}, # experiment
'info': info, # experiment config info
'stop_time': run.get('stop_time', None) # The time the experiment ended
}
"""
pass
@abstractmethod
def _list_experiments(self, exp_name=None):
"""Get experiment basic info list with experiment name
:param exp_name: experiment name
:return: list
"""
pass
@abstractmethod
def _iter_artifacts(self, experiment):
"""Get information about the data in the experiment results
:param experiment: `self._get_experiment` method result
:return: iterable
Each element contains two elements.
first element : data name
second element : data uri
"""
pass
@abstractmethod
def _load_data(self, uri):
"""Load data with uri
:param uri: data uri
:return: bytes
"""
pass
@staticmethod
def model_dict_to_buffer_list(model_dict):
"""
:param model_dict:
:return:
"""
model_list = []
is_static_model = False
if len(model_dict) == 1 and list(model_dict.keys())[0] == "model.bin":
is_static_model = True
model_list.append(list(model_dict.values())[0])
else:
sep = "model.bin_"
model_ids = list(map(lambda x: int(x.split(sep)[1]), model_dict.keys()))
min_id, max_id = min(model_ids), max(model_ids)
for i in range(min_id, max_id + 1):
model_key = sep + str(i)
model = model_dict.get(model_key, None)
if model is None:
print(
"WARNING: In Fetcher, {} is missing when the get model is in the get_experiment function.".format(
model_key
)
)
break
else:
model_list.append(model)
if is_static_model:
return model_list[0]
return model_list
def get_experiments(self, exp_name=None):
"""Get experiments with name.
:param exp_name: str
If `exp_name` is set to None, then all experiments will return.
:return: dict
Experiments info dict(Including experiment id and task_config to run the
experiment). Here is an example below.
{
'a_experiment': [
{
'id': '1',
'task_config': {...}
},
...
]
...
}
"""
res = dict()
for ex in self._list_experiments(exp_name):
name = ex["experiment"]["name"]
tmp = {
"id": ex["_id"],
"task_config": ex["info"].get("task_config", {}),
"ex_run_stop_time": ex.get("stop_time", None),
}
res.setdefault(name, []).append(tmp)
return res
def get_experiment(self, exp_name, exp_id, fields=None):
"""
:param exp_name:
:param exp_id:
:param fields: list
Experiment result fields, if fields is None, will get all fields.
Currently supported fields:
['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
:return: dict
"""
fields = copy.copy(fields)
ex = self._get_experiment(exp_name, exp_id)
results = dict()
model_dict = dict()
for name, uri in self._iter_artifacts(ex):
# When saving, use `sacred.experiment.add_artifact(filename)` , so `name` is os.path.basename(filename)
prefix = name.split(".")[0]
if fields and prefix not in fields:
continue
data = self._load_data(uri)
if prefix == "model":
model_dict[name] = data
else:
results[prefix] = pickle.loads(data)
# Sort model
if model_dict:
results["model"] = self.model_dict_to_buffer_list(model_dict)
# Info
results["task_config"] = ex["info"].get("task_config", {})
return results
def estimator_config_to_dict(self, exp_name, exp_id):
"""Save configuration to file
:param exp_name:
:param exp_id:
:return: config dict
"""
return self.get_experiment(exp_name, exp_id, fields=["task_config"])["task_config"]
class FileFetcher(Fetcher):
"""File Fetcher"""
def __init__(self, experiments_dir):
self.experiments_dir = Path(experiments_dir)
def _get_experiment(self, exp_name, exp_id):
path = self.experiments_dir / exp_name / "sacred" / str(exp_id)
info_path = path / "info.json"
run_path = path / "run.json"
if info_path.exists():
with info_path.open("r") as f:
info = json.load(f)
else:
info = {}
if run_path.exists():
with run_path.open("r") as f:
run = json.load(f)
else:
run = {}
exp = {
"_id": exp_id,
"path": path,
"experiment": {"name": exp_name},
"info": info,
"stop_time": run.get("stop_time", None),
}
return exp
def _list_experiments(self, exp_name=None):
runs = []
for path in self.experiments_dir.glob("{}/sacred/[!_]*".format(exp_name or "*")):
exp_name, exp_id = path.parents[1].name, path.name
runs.append(self._get_experiment(exp_name, exp_id))
return runs
def _iter_artifacts(self, experiment):
if experiment is None:
return []
for fname in experiment["path"].iterdir():
if fname.suffix == ".pkl" or ".bin" in fname.suffix:
name, uri = fname.name, str(fname)
yield name, uri
def _load_data(self, uri):
with open(uri, "rb") as f:
data = f.read()
return data
class MongoFetcher(Fetcher):
"""MongoDB Fetcher"""
def __init__(self, mongo_url, db_name):
self.mongo_url = mongo_url
self.db_name = db_name
self.client = None
self.db = None
self.runs = None
self.fs = None
self._setup_mongo_client()
def _setup_mongo_client(self):
self.client = pymongo.MongoClient(self.mongo_url)
self.db = self.client[self.db_name]
self.runs = self.db.runs
self.fs = gridfs.GridFS(self.db)
def _get_experiment(self, exp_name, exp_id):
return self.runs.find_one({"_id": exp_id})
def _list_experiments(self, exp_name=None):
if exp_name is None:
return self.runs.find()
return self.runs.find({"experiment.name": exp_name})
def _iter_artifacts(self, experiment):
if experiment is None:
return []
for artifact in experiment.get("artifacts", []):
name, uri = artifact["name"], artifact["file_id"]
yield name, uri
def _load_data(self, uri):
data = self.fs.get(uri).read()
return data
def create_fetcher_with_config(config_manager: EstimatorConfigManager, load_form_loader: bool = False):
"""Create fetcher with loader config
:param config_manager:
:param load_form_loader
:return:
"""
flag = ""
if load_form_loader:
flag = "loader_"
if config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_FILE_STORAGE:
return FileFetcher(eval("config_manager.ex_config.{}_dir".format("loader" if load_form_loader else "global")))
elif config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
return MongoFetcher(
mongo_url=eval("config_manager.ex_config.{}mongo_url".format(flag)),
db_name=eval("config_manager.ex_config.{}db_name".format(flag)),
)
else:
return NotImplementedError("Unkown Backend")

View File

@@ -1,585 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
import abc
import bisect
import logging
import pandas as pd
import numpy as np
from ...log import get_module_logger, TimeInspector
from ...data import D
from ...utils import parse_config, transform_end_date
from . import processor as processor_module
class BaseDataHandler(abc.ABC):
def __init__(self, processors=[], **kwargs):
"""
:param start_date:
:param end_date:
:param kwargs:
"""
# Set logger
self.logger = get_module_logger("DataHandler")
# init data using kwargs
self._init_kwargs(**kwargs)
# Setup data.
self.raw_df, self.feature_names, self.label_names = self._init_raw_df()
# Setup preprocessor
self.processors = []
for klass in processors:
if isinstance(klass, str):
try:
klass = getattr(processor_module, klass)
except:
raise ValueError("unknown Processor %s" % klass)
self.processors.append(klass(self.feature_names, self.label_names, **kwargs))
def _init_kwargs(self, **kwargs):
"""
init the kwargs of DataHandler
"""
pass
def _init_raw_df(self):
"""
init raw_df, feature_names, label_names of DataHandler
if the index of df_feature and df_label are not same, user need to overload this method to merge (e.g. inner, left, right merge).
"""
df_features = self.setup_feature()
feature_names = df_features.columns
df_labels = self.setup_label()
label_names = df_labels.columns
raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left")
return raw_df, feature_names, label_names
def reset_label(self, df_labels):
for col in self.label_names:
del self.raw_df[col]
self.label_names = df_labels.columns
self.raw_df = self.raw_df.merge(df_labels, left_index=True, right_index=True, how="left")
def split_rolling_periods(
self,
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
rolling_period,
calendar_freq="day",
):
"""
Calculating the Rolling split periods, the period rolling on market calendar.
:param train_start_date:
:param train_end_date:
:param validate_start_date:
:param validate_end_date:
:param test_start_date:
:param test_end_date:
:param rolling_period: The market period of rolling
:param calendar_freq: The frequence of the market calendar
:yield: Rolling split periods
"""
def get_start_index(calendar, start_date):
start_index = bisect.bisect_left(calendar, start_date)
return start_index
def get_end_index(calendar, end_date):
end_index = bisect.bisect_right(calendar, end_date)
return end_index - 1
calendar = self.raw_df.index.get_level_values("datetime").unique()
train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date))
train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date))
valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date))
valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date))
test_start_index = get_start_index(calendar, pd.Timestamp(test_start_date))
test_end_index = test_start_index + rolling_period - 1
need_stop_split = False
bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date))
while not need_stop_split:
if test_end_index > bound_test_end_index:
test_end_index = bound_test_end_index
need_stop_split = True
yield (
calendar[train_start_index],
calendar[train_end_index],
calendar[valid_start_index],
calendar[valid_end_index],
calendar[test_start_index],
calendar[test_end_index],
)
train_start_index += rolling_period
train_end_index += rolling_period
valid_start_index += rolling_period
valid_end_index += rolling_period
test_start_index += rolling_period
test_end_index += rolling_period
def get_rolling_data(
self,
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
rolling_period,
calendar_freq="day",
):
# Set generator.
for period in self.split_rolling_periods(
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
rolling_period,
calendar_freq,
):
(
x_train,
y_train,
x_validate,
y_validate,
x_test,
y_test,
) = self.get_split_data(*period)
yield x_train, y_train, x_validate, y_validate, x_test, y_test
def get_split_data(
self,
train_start_date,
train_end_date,
validate_start_date,
validate_end_date,
test_start_date,
test_end_date,
):
"""
all return types are DataFrame
"""
## TODO: loc can be slow, expecially when we put it at the second level index.
if self.raw_df.index.names[0] == "instrument":
df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date]
df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date]
df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date]
else:
df_train = self.raw_df.loc[train_start_date:train_end_date]
df_validate = self.raw_df.loc[validate_start_date:validate_end_date]
df_test = self.raw_df.loc[test_start_date:test_end_date]
TimeInspector.set_time_mark()
df_train, df_validate, df_test = self.setup_process_data(df_train, df_validate, df_test)
TimeInspector.log_cost_time("Finished setup processed data.")
x_train = df_train[self.feature_names]
y_train = df_train[self.label_names]
x_validate = df_validate[self.feature_names]
y_validate = df_validate[self.label_names]
x_test = df_test[self.feature_names]
y_test = df_test[self.label_names]
return x_train, y_train, x_validate, y_validate, x_test, y_test
def setup_process_data(self, df_train, df_valid, df_test):
"""
process the train, valid and test data
:return: the processed train, valid and test data.
"""
for processor in self.processors:
df_train, df_valid, df_test = processor(df_train, df_valid, df_test)
return df_train, df_valid, df_test
def get_origin_test_label_with_date(self, test_start_date, test_end_date, freq="day"):
"""Get origin test label
:param test_start_date: test start date
:param test_end_date: test end date
:param freq: freq
:return: pd.DataFrame
"""
test_end_date = transform_end_date(test_end_date, freq=freq)
return self.raw_df.loc[(slice(None), slice(test_start_date, test_end_date)), self.label_names]
@abc.abstractmethod
def setup_feature(self):
"""
Implement this method to load raw feature.
the format of the feature is below
return: df_features
"""
pass
@abc.abstractmethod
def setup_label(self):
"""
Implement this method to load and calculate label.
the format of the label is below
return: df_label
"""
pass
class QLibDataHandler(BaseDataHandler):
def __init__(self, start_date, end_date, *args, **kwargs):
# Dates.
self.start_date = start_date
self.end_date = end_date
super().__init__(*args, **kwargs)
def _init_kwargs(self, **kwargs):
# Instruments
instruments = kwargs.get("instruments", None)
if instruments is None:
market = kwargs.get("market", "csi500").lower()
data_filter_list = kwargs.get("data_filter_list", list())
self.instruments = D.instruments(market, filter_pipe=data_filter_list)
else:
self.instruments = instruments
# Config of features and labels
self._fields = kwargs.get("fields", [])
self._names = kwargs.get("names", [])
self._labels = kwargs.get("labels", [])
self._label_names = kwargs.get("label_names", [])
# Check arguments
assert len(self._fields) > 0, "features list is empty"
assert len(self._labels) > 0, "labels list is empty"
# Check end_date
# If test_end_date is -1 or greater than the last date, the last date is used
self.end_date = transform_end_date(self.end_date)
def setup_feature(self):
"""
Load the raw data.
return: df_features
"""
TimeInspector.set_time_mark()
if len(self._names) == 0:
names = ["F%d" % i for i in range(len(self._fields))]
else:
names = self._names
df_features = D.features(self.instruments, self._fields, self.start_date, self.end_date)
df_features.columns = names
TimeInspector.log_cost_time("Finished loading features.")
return df_features
def setup_label(self):
"""
Build up labels in df through users' method
:return: df_labels
"""
TimeInspector.set_time_mark()
if len(self._label_names) == 0:
label_names = ["LABEL%d" % i for i in range(len(self._labels))]
else:
label_names = self._label_names
df_labels = D.features(self.instruments, self._labels, self.start_date, self.end_date)
df_labels.columns = label_names
TimeInspector.log_cost_time("Finished loading labels.")
return df_labels
def parse_config_to_fields(config):
"""create factors from config
config = {
'kbar': {}, # whether to use some hard-code kbar features
'price': { # whether to use raw price features
'windows': [0, 1, 2, 3, 4], # use price at n days ago
'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
},
'volume': { # whether to use raw volume features
'windows': [0, 1, 2, 3, 4], # use volume at n days ago
},
'rolling': { # whether to use rolling operator based features
'windows': [5, 10, 20, 30, 60], # rolling windows size
'include': ['ROC', 'MA', 'STD'], # rolling operator to use
#if include is None we will use default operators
'exclude': ['RANK'], # rolling operator not to use
}
}
"""
fields = []
names = []
if "kbar" in config:
fields += [
"($close-$open)/$open",
"($high-$low)/$open",
"($close-$open)/($high-$low+1e-12)",
"($high-Greater($open, $close))/$open",
"($high-Greater($open, $close))/($high-$low+1e-12)",
"(Less($open, $close)-$low)/$open",
"(Less($open, $close)-$low)/($high-$low+1e-12)",
"(2*$close-$high-$low)/$open",
"(2*$close-$high-$low)/($high-$low+1e-12)",
]
names += [
"KMID",
"KLEN",
"KMID2",
"KUP",
"KUP2",
"KLOW",
"KLOW2",
"KSFT",
"KSFT2",
]
if "price" in config:
windows = config["price"].get("windows", range(5))
feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
for field in feature:
field = field.lower()
fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
names += [field.upper() + str(d) for d in windows]
if "volume" in config:
windows = config["volume"].get("windows", range(5))
fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
names += ["VOLUME" + str(d) for d in windows]
if "rolling" in config:
windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
include = config["rolling"].get("include", None)
exclude = config["rolling"].get("exclude", [])
# `exclude` in dataset config unnecessary filed
# `include` in dataset config necessary field
use = lambda x: x not in exclude and (include is None or x in include)
if use("ROC"):
fields += ["Ref($close, %d)/$close" % d for d in windows]
names += ["ROC%d" % d for d in windows]
if use("MA"):
fields += ["Mean($close, %d)/$close" % d for d in windows]
names += ["MA%d" % d for d in windows]
if use("STD"):
fields += ["Std($close, %d)/$close" % d for d in windows]
names += ["STD%d" % d for d in windows]
if use("BETA"):
fields += ["Slope($close, %d)/$close" % d for d in windows]
names += ["BETA%d" % d for d in windows]
if use("RSQR"):
fields += ["Rsquare($close, %d)" % d for d in windows]
names += ["RSQR%d" % d for d in windows]
if use("RESI"):
fields += ["Resi($close, %d)/$close" % d for d in windows]
names += ["RESI%d" % d for d in windows]
if use("MAX"):
fields += ["Max($high, %d)/$close" % d for d in windows]
names += ["MAX%d" % d for d in windows]
if use("LOW"):
fields += ["Min($low, %d)/$close" % d for d in windows]
names += ["MIN%d" % d for d in windows]
if use("QTLU"):
fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
names += ["QTLU%d" % d for d in windows]
if use("QTLD"):
fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
names += ["QTLD%d" % d for d in windows]
if use("RANK"):
fields += ["Rank($close, %d)" % d for d in windows]
names += ["RANK%d" % d for d in windows]
if use("RSV"):
fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
names += ["RSV%d" % d for d in windows]
if use("IMAX"):
fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
names += ["IMAX%d" % d for d in windows]
if use("IMIN"):
fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
names += ["IMIN%d" % d for d in windows]
if use("IMXD"):
fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
names += ["IMXD%d" % d for d in windows]
if use("CORR"):
fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
names += ["CORR%d" % d for d in windows]
if use("CORD"):
fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
names += ["CORD%d" % d for d in windows]
if use("CNTP"):
fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
names += ["CNTP%d" % d for d in windows]
if use("CNTN"):
fields += ["Mean($close<Ref($close, 1), %d)" % d for d in windows]
names += ["CNTN%d" % d for d in windows]
if use("CNTD"):
fields += ["Mean($close>Ref($close, 1), %d)-Mean($close<Ref($close, 1), %d)" % (d, d) for d in windows]
names += ["CNTD%d" % d for d in windows]
if use("SUMP"):
fields += [
"Sum(Greater($close-Ref($close, 1), 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMP%d" % d for d in windows]
if use("SUMN"):
fields += [
"Sum(Greater(Ref($close, 1)-$close, 0), %d)/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["SUMN%d" % d for d in windows]
if use("SUMD"):
fields += [
"(Sum(Greater($close-Ref($close, 1), 0), %d)-Sum(Greater(Ref($close, 1)-$close, 0), %d))"
"/(Sum(Abs($close-Ref($close, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["SUMD%d" % d for d in windows]
if use("VMA"):
fields += ["Mean($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VMA%d" % d for d in windows]
if use("VSTD"):
fields += ["Std($volume, %d)/($volume+1e-12)" % d for d in windows]
names += ["VSTD%d" % d for d in windows]
if use("WVMA"):
fields += [
"Std(Abs($close/Ref($close, 1)-1)*$volume, %d)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, %d)+1e-12)"
% (d, d)
for d in windows
]
names += ["WVMA%d" % d for d in windows]
if use("VSUMP"):
fields += [
"Sum(Greater($volume-Ref($volume, 1), 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["VSUMP%d" % d for d in windows]
if use("VSUMN"):
fields += [
"Sum(Greater(Ref($volume, 1)-$volume, 0), %d)/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d)
for d in windows
]
names += ["VSUMN%d" % d for d in windows]
if use("VSUMD"):
fields += [
"(Sum(Greater($volume-Ref($volume, 1), 0), %d)-Sum(Greater(Ref($volume, 1)-$volume, 0), %d))"
"/(Sum(Abs($volume-Ref($volume, 1)), %d)+1e-12)" % (d, d, d)
for d in windows
]
names += ["VSUMD%d" % d for d in windows]
return fields, names
class ConfigQLibDataHandler(QLibDataHandler):
config_template = {} # template
def __init__(self, start_date, end_date, processors=None, **kwargs):
if processors is None:
processors = ["ConfigSectionProcessor"] # default processor
super().__init__(start_date, end_date, processors, **kwargs)
def _init_kwargs(self, **kwargs):
config = self.config_template.copy()
if "config_update" in kwargs:
config.update(kwargs["config_update"])
fields, names = parse_config_to_fields(config)
kwargs["fields"] = fields
kwargs["names"] = names
if "labels" not in kwargs:
kwargs["labels"] = ["Ref($vwap, -2)/Ref($vwap, -1) - 1"]
super()._init_kwargs(**kwargs)
class ALPHA360(ConfigQLibDataHandler):
config_template = {
"price": {"windows": range(60)},
"volume": {"windows": range(60)},
}
class QLibDataHandlerV1(ConfigQLibDataHandler):
config_template = {
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "VWAP"],
},
"rolling": {},
}
def __init__(self, start_date, end_date, processors=None, **kwargs):
if processors is None:
processors = ["PanelProcessor"] # V1 default processor
super().__init__(start_date, end_date, processors, **kwargs)
def setup_label(self):
"""
load the labels df
:return: df_labels
"""
TimeInspector.set_time_mark()
df_labels = super().setup_label()
## calculate new labels
df_labels["LABEL1"] = df_labels["LABEL0"].groupby(level="datetime").apply(lambda x: (x - x.mean()) / x.std())
df_labels = df_labels.drop(["LABEL0"], axis=1)
TimeInspector.log_cost_time("Finished loading labels.")
return df_labels
class Alpha158(QLibDataHandlerV1):
config_template = {
"kbar": {},
"price": {
"windows": [0],
"feature": ["OPEN", "HIGH", "LOW", "CLOSE"],
},
"rolling": {},
}
def _init_kwargs(self, **kwargs):
kwargs["labels"] = ["Ref($close, -2)/Ref($close, -1) - 1"]
super(Alpha158, self)._init_kwargs(**kwargs)
# if __name__ == '__main__':
# import qlib
#
# qlib.init()
#
# handler = ALPHA80('2010-01-01', '2018-12-31')
# data = handler.get_split_data(
# pd.Timestamp('2010-01-01'), pd.Timestamp('2014-01-01'),
# pd.Timestamp('2015-01-01'), pd.Timestamp('2016-01-01'),
# pd.Timestamp('2017-01-01'), pd.Timestamp('2018-01-01'))
# print(data[0])
# data[0].to_pickle('alpha80.pkl')

View File

@@ -1,115 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import importlib
from ... import init
from .config import EstimatorConfigManager
from ...log import get_module_logger
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
args_parser = argparse.ArgumentParser(prog="estimator")
args_parser.add_argument(
"-c",
"--config_path",
required=True,
type=str,
help="json config path indicates where to load config.",
)
args = args_parser.parse_args()
class SacredExperiment(object):
def __init__(
self,
experiment_name,
experiment_dir,
observer_type="file_storage",
mongo_url=None,
db_name=None,
):
"""__init__
:param experiment_name: The name of the experiments.
:param experiment_dir: The directory to store all the results of the experiments(This is for file_storage).
:param observer_type: The observer to record the results: the `file_storage` or `mongo`
:param mongo_url: The mongo url(for mongo observer)
:param db_name: The mongo url(for mongo observer)
"""
self.experiment_name = experiment_name
self.experiment = Experiment(self.experiment_name)
self.experiment_dir = experiment_dir
self.experiment.logger = get_module_logger("Sacred")
self.observer_type = observer_type
self.mongo_db_url = mongo_url
self.mongo_db_name = db_name
self._setup_experiment()
def _setup_experiment(self):
if self.observer_type == "file_storage":
file_storage_observer = FileStorageObserver.create(basedir=self.experiment_dir)
self.experiment.observers.append(file_storage_observer)
elif self.observer_type == "mongo":
mongo_observer = MongoObserver.create(url=self.mongo_db_url, db_name=self.mongo_db_name)
self.experiment.observers.append(mongo_observer)
else:
raise NotImplementedError("Unsupported observer type: {}".format(self.observer_type))
def add_artifact(self, filename):
self.experiment.add_artifact(filename)
def add_info(self, key, value):
self.experiment.info[key] = value
def main_wrapper(self, func):
return self.experiment.main(func)
def config_wrapper(self, func):
return self.experiment.config(func)
CONFIG_MANAGER = EstimatorConfigManager(args.config_path)
ex = SacredExperiment(
CONFIG_MANAGER.ex_config.name,
CONFIG_MANAGER.ex_config.sacred_dir,
observer_type=CONFIG_MANAGER.ex_config.observer_type,
mongo_url=CONFIG_MANAGER.ex_config.mongo_url,
db_name=CONFIG_MANAGER.ex_config.db_name,
)
# qlib init
init(
provider_uri=CONFIG_MANAGER.qlib_data_config.provider_uri,
mount_path=CONFIG_MANAGER.qlib_data_config.mount_path,
auto_mount=CONFIG_MANAGER.qlib_data_config.auto_mount,
region=CONFIG_MANAGER.qlib_data_config.region,
**CONFIG_MANAGER.qlib_data_config.args
)
@ex.main_wrapper
def _main():
# 1. Get estimator class.
estimator_class = getattr(
importlib.import_module(".estimator", package="qlib.contrib.estimator"),
"Estimator",
)
# 2. Init estimator.
estimator = estimator_class(CONFIG_MANAGER, ex)
estimator.run()
def run():
ex.experiment.run()
if __name__ == "__main__":
run()

View File

@@ -1,249 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
import numpy as np
import pandas as pd
from ...log import TimeInspector
EPS = 1e-12
class Processor(abc.ABC):
def __init__(self, feature_names, label_names, **kwargs):
self.feature_names = feature_names
self.label_names = label_names
@abc.abstractmethod
def __call__(self, df_train, df_valid, df_test):
pass
class PanelProcessor(Processor):
"""Panel Preprocessor"""
STD_NORM = "Std"
MINMAX_NORM = "MinMax"
def __init__(self, feature_names, label_names, **kwargs):
super().__init__(feature_names, label_names)
# Options.
self.dropna_label = kwargs.get("dropna_label", True)
self.dropna_feature = kwargs.get("dropna_feature", False)
self.normalize_method = kwargs.get("normalize_method", None)
self.replace_inf = kwargs.get("replace_inf_feature", False)
def __call__(self, df_train, df_valid, df_test):
"""
Preprocess the data
:param df: the dataframe to process data.
"""
# Drop null labels.
if self.dropna_label:
df_train, df_valid, df_test = self._process_drop_null_label(df_train, df_valid, df_test)
# Dropna if need.
if self.dropna_feature:
df_train, df_valid, df_test = self._process_drop_null_feature(df_train, df_valid, df_test)
# replace the 'inf' with the mean the corresponding dimension
if self.replace_inf:
df_train, df_valid, df_test = self._process_replace_inf_feature(df_train, df_valid, df_test)
# normalize data in given method.
if self.normalize_method is not None:
df_train, df_valid, df_test = self._process_normalize_feature(df_train, df_valid, df_test)
return df_train, df_valid, df_test
def _process_drop_null_label(self, df_train, df_valid, df_test):
"""
Drop null labels.
"""
TimeInspector.set_time_mark()
df_train = df_train.dropna(subset=self.label_names)
df_valid = df_valid.dropna(subset=self.label_names)
# The test data's label is Unkown. They can not be seen when preprocessing
TimeInspector.log_cost_time("Finished dropping null labels.")
return df_train, df_valid, df_test
def _process_drop_null_feature(self, df_train, df_valid, df_test):
"""
Drop data which contain null features if needed.
"""
# TODO - `Pandas.dropna` is a low performance method.
TimeInspector.set_time_mark()
df_train = df_train.dropna(subset=self.feature_names)
df_valid = df_valid.dropna(subset=self.feature_names)
df_test = df_test.dropna(subset=self.feature_names)
TimeInspector.log_cost_time("Finished dropping nan.")
return df_train, df_valid, df_test
def _process_replace_inf_feature(self, df_train, df_valid, df_test):
"""
replace the 'inf' in feature with the mean of this dimension.
"""
TimeInspector.set_time_mark()
def replace_inf(data):
def process_inf(df):
for col in df.columns:
df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean())
return df
data = data.groupby("datetime").apply(process_inf)
data.sort_index(inplace=True)
return data
df_train = replace_inf(df_train)
df_valid = replace_inf(df_valid)
df_test = replace_inf(df_test)
TimeInspector.log_cost_time("Finished replace inf.")
return df_train, df_valid, df_test
def _process_normalize_feature(self, df_train, df_valid, df_test):
"""
Normalize data if needed, we provide two method now: min-max normalization and standard normalization.
"""
TimeInspector.set_time_mark()
if self.normalize_method == self.MINMAX_NORM:
min_train = np.nanmin(df_train[self.feature_names].values, axis=0)
max_train = np.nanmax(df_train[self.feature_names].values, axis=0)
ignore = min_train == max_train
def normalize(x, min_train=min_train, max_train=max_train, ignore=ignore):
if (~ignore).all():
return (x - min_train) / (max_train - min_train)
for i in range(ignore.size):
if not ignore[i]:
x[i] = (x[i] - min_train) / (max_train - min_train)
return x
elif self.normalize_method == self.STD_NORM:
mean_train = np.nanmean(df_train[self.feature_names].values, axis=0)
std_train = np.nanstd(df_train[self.feature_names].values, axis=0)
ignore = std_train == 0
def normalize(x, mean_train=mean_train, std_train=std_train, ignore=ignore):
if (~ignore).all():
return (x - mean_train) / std_train
for i in range(ignore.size):
if not ignore[i]:
x[i] = (x[i] - mean_train) / std_train
return x
else:
raise ValueError("Normalize method {} is not allowed".format(self.normalize_method))
df_train.loc(axis=1)[self.feature_names] = normalize(df_train[self.feature_names].values)
df_valid.loc(axis=1)[self.feature_names] = normalize(df_valid[self.feature_names].values)
df_test.loc(axis=1)[self.feature_names] = normalize(df_test[self.feature_names].values)
TimeInspector.log_cost_time("Finished normalizing data.")
return df_train, df_valid, df_test
class ConfigSectionProcessor(Processor):
def __init__(self, feature_names, label_names, **kwargs):
super().__init__(feature_names, label_names)
# Options
self.fillna_feature = kwargs.get("fillna_feature", True)
self.fillna_label = kwargs.get("fillna_label", True)
self.clip_feature_outlier = kwargs.get("clip_feature_outlier", False)
self.shrink_feature_outlier = kwargs.get("shrink_feature_outlier", True)
self.clip_label_outlier = kwargs.get("clip_label_outlier", False)
def __call__(self, *args):
return [self._transform(x) for x in args]
def _transform(self, df):
def _label_norm(x):
x = x - x.mean() # copy
x /= x.std()
if self.clip_label_outlier:
x.clip(-3, 3, inplace=True)
if self.fillna_label:
x.fillna(0, inplace=True)
return x
def _feature_norm(x):
x = x - x.median() # copy
x /= x.abs().median() * 1.4826
if self.clip_feature_outlier:
x.clip(-3, 3, inplace=True)
if self.shrink_feature_outlier:
x.where(x <= 3, 3 + (x - 3).div(x.max() - 3) * 0.5, inplace=True)
x.where(x >= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
if self.fillna_feature:
x.fillna(0, inplace=True)
return x
TimeInspector.set_time_mark()
# Copy
df_new = df.copy()
# Label
cols = df.columns[df.columns.str.contains("^LABEL")]
df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm)
# Features
cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")]
df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")]
df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
_cols = [
"KMID",
"KSFT",
"OPEN",
"HIGH",
"LOW",
"CLOSE",
"VWAP",
"ROC",
"MA",
"BETA",
"RESI",
"QTLU",
"QTLD",
"RSV",
"SUMP",
"SUMN",
"SUMD",
"VSUMP",
"VSUMN",
"VSUMD",
]
pat = "|".join(["^" + x for x in _cols])
cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))]
df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^RSQR")]
df_new[cols] = df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")]
df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^MIN|^LOW0")]
df_new[cols] = df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^CORR|^CORD")]
df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
cols = df.columns[df.columns.str.contains("^WVMA")]
df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
TimeInspector.log_cost_time("Finished preprocessing data.")
return df_new

View File

@@ -1,317 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# coding=utf-8
from abc import abstractmethod
import pandas as pd
import numpy as np
from scipy.stats import pearsonr
from ...log import get_module_logger, TimeInspector
from .handler import BaseDataHandler
from .launcher import CONFIG_MANAGER
from .fetcher import create_fetcher_with_config
from ...utils import drop_nan_by_y_index, transform_end_date
class BaseTrainer(object):
def __init__(self, model_class, model_save_path, model_args, data_handler: BaseDataHandler, sacred_ex, **kwargs):
# 1. Model.
self.model_class = model_class
self.model_save_path = model_save_path
self.model_args = model_args
# 2. Data handler.
self.data_handler = data_handler
# 3. Sacred ex.
self.ex = sacred_ex
# 4. Logger.
self.logger = get_module_logger("Trainer")
# 5. Data time
self.train_start_date = kwargs.get("train_start_date", None)
self.train_end_date = kwargs.get("train_end_date", None)
self.validate_start_date = kwargs.get("validate_start_date", None)
self.validate_end_date = kwargs.get("validate_end_date", None)
self.test_start_date = kwargs.get("test_start_date", None)
self.test_end_date = transform_end_date(kwargs.get("test_end_date", None))
@abstractmethod
def train(self):
"""
Implement this method indicating how to train a model.
"""
pass
@abstractmethod
def load(self):
"""
Implement this method indicating how to restore a model and the data.
"""
pass
@abstractmethod
def get_test_pred(self):
"""
Implement this method indicating how to get prediction result(s) from a model.
"""
pass
def get_test_performance(self):
"""
Implement this method indicating how to get the performance of the model.
"""
raise NotImplementedError(f"Please implement `get_test_performance`")
def get_test_score(self):
"""
Override this method to transfer the predict result(s) into the score of the stock.
Note: If this is a multi-label training, you need to transfer predict labels into one score.
Or you can just use the result of `get_test_pred()` (you can also process the result) if this is one label training.
We use the first column of the result of `get_test_pred()` as default method (regard it as one label training).
"""
pred = self.get_test_pred()
pred_score = pd.DataFrame(index=pred.index)
pred_score["score"] = pred.iloc(axis=1)[0]
return pred_score
class StaticTrainer(BaseTrainer):
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
super(StaticTrainer, self).__init__(model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs)
self.model = None
split_data = self.data_handler.get_split_data(
self.train_start_date,
self.train_end_date,
self.validate_start_date,
self.validate_end_date,
self.test_start_date,
self.test_end_date,
)
(
self.x_train,
self.y_train,
self.x_validate,
self.y_validate,
self.x_test,
self.y_test,
) = split_data
def train(self):
TimeInspector.set_time_mark()
model = self.model_class(**self.model_args)
if CONFIG_MANAGER.ex_config.finetune:
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
if isinstance(loader_model, list):
model_index = (
-1
if CONFIG_MANAGER.ex_config.loader_model_index is None
else CONFIG_MANAGER.ex_config.loader_model_index
)
loader_model = loader_model[model_index]
model.load(loader_model)
model.finetune(self.x_train, self.y_train, self.x_validate, self.y_validate)
else:
model.fit(self.x_train, self.y_train, self.x_validate, self.y_validate)
model.save(self.model_save_path)
self.ex.add_artifact(self.model_save_path)
self.model = model
TimeInspector.log_cost_time("Finished training model.")
def load(self):
model = self.model_class(**self.model_args)
# Load model
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
if isinstance(loader_model, list):
model_index = (
-1
if CONFIG_MANAGER.ex_config.loader_model_index is None
else CONFIG_MANAGER.ex_config.loader_model_index
)
loader_model = loader_model[model_index]
model.load(loader_model)
# Save model, after load, if you don't save the model, the result of this experiment will be no model
model.save(self.model_save_path)
self.ex.add_artifact(self.model_save_path)
self.model = model
def get_test_pred(self):
pred = self.model.predict(self.x_test)
pred = pd.DataFrame(pred, index=self.x_test.index, columns=self.y_test.columns)
return pred
def get_test_performance(self):
try:
model_score = self.model.score(self.x_test, self.y_test)
except NotImplementedError:
model_score = None
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
x_test, y_test, __ = drop_nan_by_y_index(self.x_test, self.y_test)
pred_test = self.model.predict(x_test)
model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
return performance
class RollingTrainer(BaseTrainer):
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
super(RollingTrainer, self).__init__(
model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs
)
self.rolling_period = kwargs.get("rolling_period", 60)
self.models = []
self.rolling_data = []
self.all_x_test = []
self.all_y_test = []
for data in self.data_handler.get_rolling_data(
self.train_start_date,
self.train_end_date,
self.validate_start_date,
self.validate_end_date,
self.test_start_date,
self.test_end_date,
self.rolling_period,
):
self.rolling_data.append(data)
__, __, __, __, x_test, y_test = data
self.all_x_test.append(x_test)
self.all_y_test.append(y_test)
def train(self):
# 1. Get total data parts.
# total_data_parts = self.data_handler.total_data_parts
# self.logger.warning('Total numbers of model are: {}, start training models...'.format(total_data_parts))
if CONFIG_MANAGER.ex_config.finetune:
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
loader_model_index = CONFIG_MANAGER.ex_config.loader_model_index
previous_model_path = ""
# 2. Rolling train.
for (
index,
(x_train, y_train, x_validate, y_validate, x_test, y_test),
) in enumerate(self.rolling_data):
TimeInspector.set_time_mark()
model = self.model_class(**self.model_args)
if CONFIG_MANAGER.ex_config.finetune:
# Finetune model
if loader_model_index is None and isinstance(loader_model, list):
try:
model.load(loader_model[index])
except IndexError:
# Load model by previous_model_path
with open(previous_model_path, "rb") as fp:
model.load(fp)
model.finetune(x_train, y_train, x_validate, y_validate)
else:
if index == 0:
loader_model = (
loader_model[loader_model_index] if isinstance(loader_model, list) else loader_model
)
model.load(loader_model)
else:
with open(previous_model_path, "rb") as fp:
model.load(fp)
model.finetune(x_train, y_train, x_validate, y_validate)
else:
model.fit(x_train, y_train, x_validate, y_validate)
model_save_path = "{}_{}".format(self.model_save_path, index)
model.save(model_save_path)
previous_model_path = model_save_path
self.ex.add_artifact(model_save_path)
self.models.append(model)
TimeInspector.log_cost_time("Finished training model: {}.".format(index + 1))
def load(self):
"""
Load the data and the model
"""
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
loader_model = fetcher.get_experiment(
exp_name=CONFIG_MANAGER.ex_config.loader_name,
exp_id=CONFIG_MANAGER.ex_config.loader_id,
fields=["model"],
)["model"]
for index in range(len(self.all_x_test)):
model = self.model_class(**self.model_args)
model.load(loader_model[index])
# Save model
model_save_path = "{}_{}".format(self.model_save_path, index)
model.save(model_save_path)
self.ex.add_artifact(model_save_path)
self.models.append(model)
def get_test_pred(self):
"""
Predict the score on test data with the models.
Please ensure the models and data are loaded before call this score.
:return: the predicted scores for the pred
"""
pred_df_list = []
y_test_columns = self.all_y_test[0].columns
# Start iteration.
for model, x_test in zip(self.models, self.all_x_test):
pred = model.predict(x_test)
pred_df = pd.DataFrame(pred, index=x_test.index, columns=y_test_columns)
pred_df_list.append(pred_df)
return pd.concat(pred_df_list)
def get_test_performance(self):
"""
Get the performances of the models
:return: the performances of models
"""
pred_test_list = []
y_test_list = []
scorer = self.models[0]._scorer
for model, x_test, y_test in zip(self.models, self.all_x_test, self.all_y_test):
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
pred_test_list.append(model.predict(x_test))
y_test_list.append(np.squeeze(y_test.values))
pred_test_array = np.concatenate(pred_test_list, axis=0)
y_test_array = np.concatenate(y_test_list, axis=0)
model_score = scorer(y_test_array, pred_test_array)
model_pearsonr = pearsonr(np.ravel(y_test_array), np.ravel(pred_test_array))[0]
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
return performance

View File

Some files were not shown because too many files have changed in this diff Show More