Merge branch 'main' into dnn_drop
81
README.md
@@ -9,7 +9,7 @@
|
||||
|
||||
|
||||
<p align="center">
|
||||
<img src="http://fintech.msra.cn/images/logo/1.png" />
|
||||
<img src="http://fintech.msra.cn/images_v060/logo/1.png" />
|
||||
</p>
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [Quant Model Zoo](#quant-model-zoo)
|
||||
- [Run a single model](#run-a-single-model)
|
||||
- [Run multiple models](#run-multiple-models)
|
||||
- [Quant Dataset Zoo](#quant-dataset-zoo)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
@@ -39,19 +41,17 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images/framework.png" />
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png" />
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Data layer` | `DataServer` focuses on providing high-performance infrastructure for users to manage and retrieve raw data. `DataEnhancement` will preprocess the data and provide the best dataset to be fed into the models. |
|
||||
| `Interday Model` | `Interday model` focuses on producing prediction scores (aka. _alpha_). Models are trained by `Model Creator` and managed by `Model Manager`. Users could choose one or multiple models for prediction. Multiple models could be combined with `Ensemble` module. |
|
||||
| `Interday Strategy` | `Portfolio Generator` will take prediction scores as input and output the orders based on the current position to achieve the target portfolio. |
|
||||
| `Intraday Trading` | `Order Executor` is responsible for executing orders output by `Interday Strategy` and returning the executed results. |
|
||||
| `Analysis` | Users could get a detailed analysis report of forecasting signals and portfolios in this part. |
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
* The modules with dashed borders are highly user-customizable and extendible.
|
||||
@@ -139,17 +139,20 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
|
||||
```bash
|
||||
|
||||
risk
|
||||
excess_return_without_cost mean 0.000675
|
||||
std 0.005456
|
||||
annualized_return 0.170077
|
||||
information_ratio 1.963824
|
||||
max_drawdown -0.063646
|
||||
excess_return_with_cost mean 0.000479
|
||||
std 0.005453
|
||||
annualized_return 0.120776
|
||||
information_ratio 1.395116
|
||||
max_drawdown -0.071216
|
||||
'The following are analysis results of the excess return without cost.'
|
||||
risk
|
||||
mean 0.000708
|
||||
std 0.005626
|
||||
annualized_return 0.178316
|
||||
information_ratio 1.996555
|
||||
max_drawdown -0.081806
|
||||
'The following are analysis results of the excess return with cost.'
|
||||
risk
|
||||
mean 0.000512
|
||||
std 0.005626
|
||||
annualized_return 0.128982
|
||||
information_ratio 1.444287
|
||||
max_drawdown -0.091078
|
||||
|
||||
|
||||
|
||||
@@ -159,19 +162,19 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
|
||||
- Forecasting signal (model prediction) analysis
|
||||
- Cumulative Return of groups
|
||||

|
||||

|
||||
- Return distribution
|
||||

|
||||

|
||||
- Information Coefficient (IC)
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
- Auto Correlation of forecasting signal (model prediction)
|
||||

|
||||

|
||||
|
||||
- Portfolio analysis
|
||||
- Backtest return
|
||||

|
||||

|
||||
<!--
|
||||
- Score IC
|
||||

|
||||
@@ -187,7 +190,25 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# Quant Model Zoo
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
|
||||
## Run a single model
|
||||
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
|
||||
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
|
||||
|
||||
The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored.
|
||||
|
||||
Here is an example of running all the models for 10 iterations:
|
||||
```python
|
||||
python run_all_model.py 10
|
||||
```
|
||||
|
||||
It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
|
||||
@@ -196,10 +217,12 @@ Here is a list of models built on `Qlib`.
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
|
||||
- [TabNet based on pytorch](qlib/contrib/model/tabnet.py)
|
||||
- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
|
||||
<!-- - [TFT based on tensorflow](examples/benchmarks/TFT/tft.py) -->
|
||||
- [HATs based on pytorch](qlib/contrib/model/pytorch_hats.py)
|
||||
- [TFT based on tensorflow](examples/benchmarks/TFT/tft.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
BIN
docs/_static/img/analysis/analysis_model_IC.png
vendored
|
Before Width: | Height: | Size: 40 KiB After Width: | Height: | Size: 33 KiB |
BIN
docs/_static/img/analysis/analysis_model_NDQ.png
vendored
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 23 KiB |
|
Before Width: | Height: | Size: 52 KiB After Width: | Height: | Size: 47 KiB |
|
Before Width: | Height: | Size: 66 KiB After Width: | Height: | Size: 63 KiB |
|
Before Width: | Height: | Size: 17 KiB After Width: | Height: | Size: 16 KiB |
|
Before Width: | Height: | Size: 18 KiB After Width: | Height: | Size: 16 KiB |
BIN
docs/_static/img/analysis/report.png
vendored
|
Before Width: | Height: | Size: 163 KiB After Width: | Height: | Size: 160 KiB |
|
Before Width: | Height: | Size: 53 KiB After Width: | Height: | Size: 46 KiB |
BIN
docs/_static/img/analysis/risk_analysis_bar.png
vendored
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 13 KiB |
|
Before Width: | Height: | Size: 56 KiB After Width: | Height: | Size: 54 KiB |
|
Before Width: | Height: | Size: 57 KiB After Width: | Height: | Size: 53 KiB |
BIN
docs/_static/img/analysis/risk_analysis_std.png
vendored
|
Before Width: | Height: | Size: 47 KiB After Width: | Height: | Size: 47 KiB |
BIN
docs/_static/img/analysis/score_ic.png
vendored
|
Before Width: | Height: | Size: 105 KiB After Width: | Height: | Size: 102 KiB |
BIN
docs/_static/img/framework.png
vendored
|
Before Width: | Height: | Size: 205 KiB After Width: | Height: | Size: 271 KiB |
@@ -1,7 +1,7 @@
|
||||
.. _data:
|
||||
|
||||
================================
|
||||
Data Layer: Data Framework&Usage
|
||||
Data Layer: Data Framework & Usage
|
||||
================================
|
||||
|
||||
Introduction
|
||||
@@ -15,7 +15,9 @@ The introduction of ``Data Layer`` includes the following parts.
|
||||
|
||||
- Data Preparation
|
||||
- Data API
|
||||
- Data Loader
|
||||
- Data Handler
|
||||
- Dataset
|
||||
- Cache
|
||||
- Data and Cache File Structure
|
||||
|
||||
@@ -31,13 +33,19 @@ Such data will be stored with filename suffix `.bin` (We'll call them `.bin` fil
|
||||
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the dataset as follows.
|
||||
``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
After running the above command, users can find china-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory.
|
||||
In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
After running the above command, users can find china-stock and us-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
|
||||
|
||||
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
|
||||
|
||||
@@ -49,12 +57,45 @@ Converting CSV Format into Qlib Format
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert data in CSV format into `.bin` files (Qlib format).
|
||||
|
||||
|
||||
Users can download the china-stock data in CSV format as follows for reference to the CSV format.
|
||||
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
|
||||
- Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
|
||||
|
||||
- CSV file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
symbol,close
|
||||
SH600000,120
|
||||
|
||||
- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/dump_bin.py dump_all ... --date_field_name date
|
||||
|
||||
where the data are in the following format:
|
||||
|
||||
.. code-block::
|
||||
|
||||
symbol,date,close,open,volume
|
||||
SH600000,2020-11-01,120,121,12300000
|
||||
SH600000,2020-11-02,123,120,12300000
|
||||
|
||||
|
||||
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
|
||||
|
||||
@@ -62,6 +103,12 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv
|
||||
|
||||
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
|
||||
|
||||
For other supported parameters when dumping the data into `.bin` file, users can refer to the information by running the following commands:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python dump_bin.py dump_all --help
|
||||
|
||||
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
|
||||
|
||||
.. note::
|
||||
@@ -97,9 +144,8 @@ China-Stock Mode & US-Stock Mode
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)
|
||||
|
||||
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` does not provide a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Prepare data in CSV format
|
||||
- Convert data from CSV format to Qlib format, please refer to section `Converting CSV Format into Qlib Format <#converting-csv-format-into-qlib-format>`_.
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
@@ -141,68 +187,97 @@ Filter
|
||||
Expression dynamic instrument filter. Filter the instruments based on a certain expression. An expression rule indicating a certain feature field is required.
|
||||
|
||||
- `basic features filter`: rule_expression = '$close/$open>5'
|
||||
- `cross-sectional features filter` : rule_expression = '$rank($close)<10'
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
|
||||
Reference
|
||||
-------------
|
||||
|
||||
To know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_.
|
||||
|
||||
|
||||
Data Loader
|
||||
=================
|
||||
|
||||
``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It will be loaded and used in the ``Data Handler`` module.
|
||||
|
||||
QlibDataLoader
|
||||
---------------
|
||||
|
||||
The ``QlibDataLoader`` class in ``Qlib`` is such an interface that allows users to load raw data from the data source.
|
||||
|
||||
Interface
|
||||
------------
|
||||
|
||||
Here are some interfaces of the ``QlibDataLoader`` class:
|
||||
|
||||
.. autoclass:: qlib.data.dataset.loader.QlibDataLoader
|
||||
:members: load, load_group_df
|
||||
|
||||
API
|
||||
-----------
|
||||
|
||||
To know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_.
|
||||
|
||||
|
||||
Data Handler
|
||||
=================
|
||||
|
||||
Users can use ``Data Handler`` in an automatic workflow by ``Estimator``, refer to `Estimator: Workflow Management <estimator.html>`_ for more details.
|
||||
The ``Data Handler`` module in ``Qlib`` is designed to handler those common data processing methods which will be used by most of the models.
|
||||
|
||||
Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.data.dataset.handler.DataHandlerLP``, which provides some interfaces as follows.
|
||||
Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management <workflow.html>`_ for more details.
|
||||
|
||||
Base Class & Interface
|
||||
DataHandlerLP
|
||||
--------------
|
||||
|
||||
In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
|
||||
|
||||
In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some leanable ``Processors`` which can learn the parameters of data processing. When new data comes in, these `trained` ``Processors`` can then infer on the new data and thus processing real-time data in an efficient way. More information about ``Processors`` will be listed in the next subsection.
|
||||
|
||||
|
||||
Interface
|
||||
----------------------
|
||||
|
||||
Qlib provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_, which provides the following interfaces:
|
||||
Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
|
||||
- `load_feature`
|
||||
Implement the interface to load the data features.
|
||||
|
||||
- `load_label`
|
||||
Implement the interface to load the data labels and calculate the users' labels.
|
||||
|
||||
- `setup_processed_data`
|
||||
Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on.
|
||||
|
||||
Qlib also provides two functions to help users init the data handler, users can override them for users' needs.
|
||||
|
||||
- `_init_raw_data`
|
||||
Users can init the raw df, feature names, and label names of data handler in this function.
|
||||
If the index of feature df and label df are not the same, users need to override this method to merge them (e.g. inner, left, right merge).
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
|
||||
|
||||
Usage
|
||||
--------------
|
||||
Processor
|
||||
----------
|
||||
|
||||
``Data Handler`` can be used as a single module, which provides the following mehtods:
|
||||
The ``Processor`` module in ``Qlib`` is designed to be learnable and it is responsible for handling data processing such as `normalization` and `drop none/nan features/labels`.
|
||||
|
||||
- `get_split_data`
|
||||
- According to the start and end dates, return features and labels of the pandas DataFrame type used for the 'Model'
|
||||
|
||||
- `get_rolling_data`
|
||||
- According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling.
|
||||
``Qlib`` provides the following ``Processors``:
|
||||
|
||||
- ``DropnaProcessor``: `processor` that drops N/A features.
|
||||
- ``DropnaLabel``: `processor` that drops N/A labels.
|
||||
- ``TanhProcess``: `processor` that uses `tanh` to process noise data.
|
||||
- ``ProcessInf``: `processor` that handles infinity values, it will be replaces by the mean of the column.
|
||||
- ``Fillna``: `processor` that handles N/A values, which will fill the N/A value by 0 or other given number.
|
||||
- ``MinMaxNorm``: `processor` that applies min-max normalization.
|
||||
- ``ZscoreNorm``: `processor` that applies z-score normalization.
|
||||
- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
|
||||
- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
|
||||
- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
|
||||
|
||||
Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link <https://github.com/microsoft/qlib/blob/main/qlib/data/dataset/processor.py>`_).
|
||||
|
||||
To know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_.
|
||||
|
||||
Example
|
||||
--------------
|
||||
|
||||
``Data Handler`` can be run with ``estimator`` by modifying the configuration file, and can also be used as a single module.
|
||||
``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.
|
||||
|
||||
Know more about how to run ``Data Handler`` with ``Estimator``, please refer to `Estimator: Workflow Management <estimator.html>`_
|
||||
Know more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_
|
||||
|
||||
Qlib provides implemented data handler `Alpha158`. The following example shows how to run `Alpha158` as a single module.
|
||||
|
||||
@@ -211,45 +286,55 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
import qlib
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"dropna_label": True,
|
||||
"start_date": "2007-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"market": "csi300",
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": "csi300",
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2007-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
}
|
||||
if __name__ == "__main__":
|
||||
qlib.init()
|
||||
h = Alpha158(**data_handler_config)
|
||||
|
||||
exampleDataHandler = Alpha158(**DATA_HANDLER_CONFIG)
|
||||
# get all the columns of the data
|
||||
print(h.get_cols())
|
||||
|
||||
# example of 'get_split_data'
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG)
|
||||
# fetch all the labels
|
||||
print(h.fetch(col_set="label"))
|
||||
|
||||
# example of 'get_rolling_data'
|
||||
|
||||
for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG):
|
||||
print(x_train, y_train, x_validate, y_validate, x_test, y_test)
|
||||
|
||||
|
||||
.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the `fit`, `predic``, and `score` methods of the ``Interday Model`` , please refer to `Model <model.html#base-class-interface>`_.
|
||||
|
||||
Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_.
|
||||
|
||||
|
||||
Dataset
|
||||
=================
|
||||
|
||||
The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
|
||||
|
||||
The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the rights to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``DNN`` will break down on such data.
|
||||
|
||||
The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most important interface of the class:
|
||||
|
||||
.. autoclass:: qlib.data.dataset.__init__.DatasetH
|
||||
:members:
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
|
||||
|
||||
|
||||
|
||||
Cache
|
||||
==========
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ Interday Model: Model Training & Prediction
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``Estimator``, please refer to `Estimator: Workflow Management <estimator.html>`_.
|
||||
``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management <workflow.html>`_.
|
||||
|
||||
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Interday Model`` can be used as an independent module also.
|
||||
|
||||
@@ -20,151 +20,125 @@ The base class provides the following interfaces:
|
||||
|
||||
- `__init__(**kwargs)`
|
||||
- Initialization.
|
||||
- If users use ``Estimator`` to start an `experiment`, the parameter of `__init__` method shoule be consistent with the hyperparameters in the configuration file.
|
||||
|
||||
- `fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)`
|
||||
- `fit(self, dataset, **kwargs)`
|
||||
- Train model.
|
||||
- Parameter:
|
||||
- `x_train`, pd.DataFrame type, train feature
|
||||
The following example explains the value of `x_train`:
|
||||
- `dataset`, ``Qlib``'s ``DatasetH`` type. For more information about ``DatasetH``, users can refer to the related document: `Qlib Dataset <../component/data.html#dataset>`_.
|
||||
The `dataset` is passed into the `model`'s method because there are some unique data preprocessing procedures for each, we want to give each model maximum flexibility to handle the data that is suitable for their own.
|
||||
The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
KMID KLEN KMID2 KUP KUP2
|
||||
instrument datetime
|
||||
SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275
|
||||
2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998
|
||||
2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666
|
||||
2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001
|
||||
2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648
|
||||
... ... ... ... ... ...
|
||||
SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052
|
||||
2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824
|
||||
2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000
|
||||
2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433
|
||||
2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
`x_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. Each column of `x_train` corresponds to a feature, and the column name is the feature name.
|
||||
|
||||
.. note::
|
||||
|
||||
The number and names of the columns are determined by the data handler, please refer to `Data Handler <data.html#data-handler>`_ and `Estimator Data Section <estimator.html#data-section>`_.
|
||||
|
||||
- `y_train`, pd.DataFrame type, train label
|
||||
The following example explains the value of `y_train`:
|
||||
# get features and labels
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
.. code-block:: YAML
|
||||
|
||||
LABEL
|
||||
instrument datetime
|
||||
SH600004 2012-01-04 -0.798456
|
||||
2012-01-05 -1.366716
|
||||
2012-01-06 -0.491026
|
||||
2012-01-09 0.296900
|
||||
2012-01-10 0.501426
|
||||
... ...
|
||||
SZ300273 2014-12-25 -0.465540
|
||||
2014-12-26 0.233864
|
||||
2014-12-29 0.471368
|
||||
2014-12-30 0.411914
|
||||
2014-12-31 1.342723
|
||||
|
||||
`y_train` is a pandas DataFrame, whose index is MultiIndex <instrument(str), datetime(pd.Timestamp)>. The `LABEL` column represents the value of train label.
|
||||
|
||||
.. note::
|
||||
|
||||
The number and names of the columns are determined by the ``Data Handler``, please refer to `Data Handler <data.html#data-handler>`_.
|
||||
|
||||
- `x_valid`, pd.DataFrame type, validation feature
|
||||
The format of `x_valid` is same as `x_train`
|
||||
|
||||
|
||||
- `y_valid`, pd.DataFrame type, validation label
|
||||
The format of `y_valid` is same as `y_train`
|
||||
|
||||
- `w_train`(Optional args, default is None), pd.DataFrame type, train weight
|
||||
`w_train` is a pandas DataFrame, whose shape and index is same as `x_train`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
|
||||
|
||||
- `w_train`(Optional args, default is None), pd.DataFrame type, validation weight
|
||||
`w_train` is a pandas DataFrame, whose shape and index is the same as `x_valid`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
|
||||
|
||||
- `predict(self, x_test, **kwargs)`
|
||||
- Predict test data 'x_test'
|
||||
- Parameter:
|
||||
- `x_test`, pd.DataFrame type, test features
|
||||
The form of `x_test` is same as `x_train` in 'fit' method.
|
||||
- Return:
|
||||
- `label`, np.ndarray type, test label
|
||||
The label of `x_test` that predicted by model.
|
||||
|
||||
- `score(self, x_test, y_test, w_test=None, **kwargs)`
|
||||
- Evaluate model with test feature/label
|
||||
- Parameter:
|
||||
- `x_test`, pd.DataFrame type, test feature
|
||||
The format of `x_test` is same as `x_train` in `fit` method.
|
||||
# get weights
|
||||
try:
|
||||
wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
|
||||
w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
|
||||
except KeyError as e:
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
- `x_test`, pd.DataFrame type, test label
|
||||
The format of `y_test` is same as `y_train` in `fit` method.
|
||||
- `predict(self, dataset, **kwargs)`
|
||||
- Predict test data.
|
||||
- Parameter:
|
||||
- `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
|
||||
- Returns:
|
||||
- Predic results with type: `pandas.Series`.
|
||||
|
||||
- `w_test`, pd.DataFrame type, test weight
|
||||
The format of `w_test` is same as `w_train` in `fit` method.
|
||||
- Return: float type, evaluation score
|
||||
- `finetune(self, dataset, **kwargs)`
|
||||
- Finetune the model.
|
||||
- Parameter:
|
||||
- `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
|
||||
|
||||
For other interfaces such as `save`, `load`, `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
|
||||
|
||||
For other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
|
||||
|
||||
Example
|
||||
==================
|
||||
|
||||
``Qlib`` provides ``LightGBM`` and ``DNN`` models as the baseline, the following steps show how to run`` LightGBM`` as an independent module.
|
||||
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``DNN``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. The following steps show how to run`` LightGBM`` as an independent module.
|
||||
|
||||
- Initialize ``Qlib`` with `qlib.init` first, please refer to `Initialization <../start/initialization.html>`_.
|
||||
- Run the following code to get the `prediction score` `pred_score`
|
||||
.. code-block:: Python
|
||||
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"dropna_label": True,
|
||||
"start_date": "2007-01-01",
|
||||
"end_date": "2020-08-01",
|
||||
"market": MARKET,
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_date": "2007-01-01",
|
||||
"train_end_date": "2014-12-31",
|
||||
"validate_start_date": "2015-01-01",
|
||||
"validate_end_date": "2016-12-31",
|
||||
"test_start_date": "2017-01-01",
|
||||
"test_end_date": "2020-08-01",
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(
|
||||
**DATA_HANDLER_CONFIG
|
||||
).get_split_data(**TRAINER_CONFIG)
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
# train
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
MODEL_CONFIG = {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
}
|
||||
# use default model
|
||||
model = LGBModel(**MODEL_CONFIG)
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
_pred = model.predict(x_test)
|
||||
pred_score = pd.DataFrame(index=_pred.index)
|
||||
pred_score["score"] = _pred.iloc(axis=1)[0]
|
||||
|
||||
.. note:: `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
|
||||
.. note::
|
||||
|
||||
`Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler <data.html#data-handler>`_.
|
||||
`SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow <recorder.html#record-template>`_.
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
|
||||
|
||||
@@ -50,312 +50,17 @@ Qlib Recorder
|
||||
|
||||
Here are the available interfaces of ``QlibRecorder``:
|
||||
|
||||
- `__init__(exp_manager)`
|
||||
- Initialization.
|
||||
- It takes in an input: `exp_manager`, which is an `ExperimentManager` instance. The instance will be created during ``qlib.init``.
|
||||
|
||||
- `start(experiment_name=None, recorder_name=None)`
|
||||
- High level API to start an experiment. This method can only be called within a Python's '`with`' statement.
|
||||
- Parameters:
|
||||
- `experiment_name` : str
|
||||
name of the experiment one wants to start.
|
||||
- `recorder_name` : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
with R.start('test', 'recorder_1'):
|
||||
model.fit(dataset)
|
||||
R.log...
|
||||
... # further operations
|
||||
|
||||
- `start_exp(experiment_name=None, recorder_name=None, uri=None)`
|
||||
- Lower level method for starting an experiment. When use this method, one should end the experiment manually and the status of the recorder may not be handled properly.
|
||||
- Parameters:
|
||||
- `experiment_name` : str
|
||||
the name of the experiment to be started
|
||||
- `recorder_name` : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
- `uri` : str
|
||||
the tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
|
||||
The default uri are set in the qlib.config.
|
||||
- Returns:
|
||||
- an experiment instance being started.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
R.start_exp(experiment_name='test', recorder_name='recorder_1')
|
||||
... # further operations
|
||||
R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
|
||||
|
||||
- `end_exp(recorder_status=Recorder.STATUS_FI)`
|
||||
- Method for ending an experiment manually. It will end the current active experiment, as well as its active recorder with the specified `status` type.
|
||||
- Parameters:
|
||||
- `status` : str
|
||||
The status of a recorder, which can be '`SCHEDULED`', '`RUNNING`', '`FINISHED`', '`FAILED`'.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
R.start_exp(experiment_name='test')
|
||||
... # further operations
|
||||
R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
|
||||
|
||||
- `search_records(experiment_ids, **kwargs)`
|
||||
- Get a pandas DataFrame of all the records that have been stored with the given search criteria. This method is highly correlated with MLFlow's ``search_runs`` method (`link <https://www.mlflow.org/docs/latest/python_api/mlflow.html#mlflow.search_runs>`_).
|
||||
- Parameters:
|
||||
- `experiment_ids` : list
|
||||
list of experiment IDs.
|
||||
- `filter_string` : str
|
||||
filter query string, defaults to searching all runs.
|
||||
- `run_view_type` : int
|
||||
one of enum values ACTIVE_ONLY (1), DELETED_ONLY (2), or ALL (3).
|
||||
- `max_results` : int
|
||||
the maximum number of runs to put in the dataframe.
|
||||
- `order_by` : list
|
||||
list of columns to order by (e.g., “metrics.rmse”).
|
||||
- Returns:
|
||||
- A pandas.DataFrame of records, where each metric, parameter, and tag are expanded into their own columns named metrics.*, params.*, and tags.* respectively. For records that don't have a particular metric, parameter, or tag, their value will be (NumPy) Nan, None, or None respectively.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
R.log_metrics(m=2.50, step=0)
|
||||
records = R.search_runs([experiment_id], order_by=["metrics.m DESC"])
|
||||
|
||||
- `list_experiments()`
|
||||
- Method for listing all the existing experiments (except for those being deleted.)
|
||||
- Returns:
|
||||
- A dictionary (name -> experiment) of experiments information that being stored.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
exps = R.list_experiments()
|
||||
|
||||
- `list_recorders(experiment_id=None, experiment_name=None)`
|
||||
- Method for listing all the recorders of experiment with given id or name. If user doesn't provide the id or name of the experiment, this method will try to retrieve the default experiment and list all the recorders of the default experiment. If the default experiment doesn't exist, the method will first create the default experiment, and then create a new recorder under it.
|
||||
- Parameters:
|
||||
- `experiment_id` : str
|
||||
id of the experiment.
|
||||
- `experiment_name` : str
|
||||
name of the experiment.
|
||||
- Returns:
|
||||
- A dictionary (id -> recorder) of recorder information that being stored.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
recorders = R.list_recorders(experiment_name='test')
|
||||
|
||||
- `get_exp(experiment_id=None, experiment_name=None, create: bool = True)`
|
||||
- Method for retrieving an experiment with given id or name. Once the '`create`' argument is set to True, if no valid experiment is found, this method will create one for the user. Otherwise, it will only retrieve a specific experiment or raise an Error.
|
||||
|
||||
- If '`create`' is True:
|
||||
- If ``R``'s running:
|
||||
- no id or name specified, return the active experiment.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be running.
|
||||
- If ``R``'s not running:
|
||||
- no id or name specified, create a default experiment, and the experiment is set to be running.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be running.
|
||||
- Else If '`create`' is False:
|
||||
- If ``R``'s running:
|
||||
- no id or name specified, return the active experiment.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, raise Error.
|
||||
- If ``R``'s not running:
|
||||
- no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, raise Error.
|
||||
- Parameters:
|
||||
- `experiment_id` : str
|
||||
id of the experiment.
|
||||
- `experiment_name` : str
|
||||
name of the experiment.
|
||||
- `create` : boolean
|
||||
an argument determines whether the method will automatically create a new experiment according to user's specification if the experiment hasn't been created before.
|
||||
- Returns:
|
||||
- An experiment instance with given id or name.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
exp = R.get_exp()
|
||||
recorders = exp.list_recorders()
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
exp = R.get_exp('test1')
|
||||
|
||||
# Case 3
|
||||
exp = R.get_exp() -> a default experiment.
|
||||
|
||||
# Case 4
|
||||
exp = R.get_exp(experiment_name='test')
|
||||
|
||||
# Case 5
|
||||
exp = R.get_exp(create=False) -> the default experiment if exists.
|
||||
|
||||
- `delete_exp(experiment_id=None, experiment_name=None)`
|
||||
- Method for deleting the experiment with given id or name. At least one of id or name must be given, otherwise, error will occur.
|
||||
- Parameters:
|
||||
- `experiment_id` : str
|
||||
id of the experiment.
|
||||
- `experiment_name` : str
|
||||
name of the experiment.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
R.delete_exp(experiment_name='test')
|
||||
|
||||
- `get_uri()`
|
||||
- Method for retrieving the uri of current experiment manager.
|
||||
- Returns:
|
||||
- The uri of current experiment manager.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
uri = R.get_uri()
|
||||
|
||||
- `get_recorder(recorder_id=None, recorder_name=None, experiment_name=None)`
|
||||
- Method for retrieving a recorder. The recorder can be used for further process such as ``save_objects``, ``load_object``, ``log_params``, ``log_metrics``, etc.
|
||||
|
||||
- If ``R``'s running:
|
||||
- no id or name specified, return the active recorder.
|
||||
- if id or name is specified, return the specified recorder.
|
||||
- If ``R``'s not running:
|
||||
- no id or name specified, raise Error.
|
||||
- if id or name is specified, and the corresponding experiment_name must be given, return the specified recorder. Otherwise, raise Error.
|
||||
- Parameters:
|
||||
- `recorder_id` : str
|
||||
id of the recorder.
|
||||
- `recorder_name` : str
|
||||
name of the recorder.
|
||||
- `experiment_name` : str
|
||||
name of the experiment.
|
||||
- Returns:
|
||||
- A recorder instance.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
recorder = R.get_recorder()
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
|
||||
|
||||
# Case 3
|
||||
recorder = R.get_recorder() -> Error
|
||||
|
||||
# Case 4
|
||||
recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d') -> Error
|
||||
|
||||
# Case 5
|
||||
recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')
|
||||
|
||||
- `delete_recorder(recorder_id=None, recorder_name=None)`
|
||||
- Method for deleting the recorders with given id or name. At least one of id or name must be given, otherwise, error will occur.
|
||||
- Parameters:
|
||||
- `recorder_id` : str
|
||||
id of the experiment.
|
||||
- `recorder_name` : str
|
||||
name of the experiment.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
R.delete_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
|
||||
|
||||
- `save_objects(local_path=None, artifact_path=None, **kwargs)`
|
||||
- Method for saving objects as artifacts in the experiment to the uri. It supports either saving from a local file/directory, or directly saving objects. User can use valid python's keywords arguments to specify the object to be saved as well as its name (name: value).
|
||||
|
||||
- If R's running: it will save the objects through the running recorder.
|
||||
- If R's not running: the system will create a default experiment, and a new recorder and save objects under it.
|
||||
|
||||
.. note::
|
||||
|
||||
If one wants to save objects with a specific recorder. It is recommended to first get the specific recorder through `get_recorder` API and use the recorder the save objects. The supported arguments are the same as this method.
|
||||
|
||||
- Parameters:
|
||||
- `local_path` : str
|
||||
if provided, them save the file or directory to the artifact URI.
|
||||
- `artifact_path` : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
pred = model.predict(dataset)
|
||||
R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
R.save_objects(local_path='results/pred.pkl')
|
||||
|
||||
- `log_params(**kwargs)`
|
||||
- Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
|
||||
|
||||
- If R's running: it will log parameters through the running recorder.
|
||||
- If R's not running: the system will create a default experiment as well as a new recorder, and log parameters under it.
|
||||
- Parameters:
|
||||
- `keyword argument`:
|
||||
name1=value1, name2=value2, ...
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
R.log_params(learning_rate=0.01)
|
||||
|
||||
# Case 2
|
||||
R.log_params(learning_rate=0.01)
|
||||
|
||||
- `log_metrics(step=None, **kwargs)`
|
||||
- Method for logging metrics during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
|
||||
|
||||
- If R's running: it will log metrics through the running recorder.
|
||||
- If R's not running: the system will create a default experiment as well as a new recorder, and log metrics under it.
|
||||
- Parameters:
|
||||
- `step`: int
|
||||
a single integer step at which to log the specified Metrics. If unspecified, each metric is logged at step zero.
|
||||
- `keyword argument`:
|
||||
name1=value1, name2=value2, ...
|
||||
|
||||
- `set_tags(**kwargs)`
|
||||
- Method for setting tags for a recorder. In addition to using ``R``, one can also set the tag to a specific recorder after getting it with `get_recorder` API.
|
||||
|
||||
- If R's running: it will set tags through the running recorder.
|
||||
- If R's not running: the system will create a default experiment as well as a new recorder, and set the tags under it.
|
||||
- Parameters:
|
||||
- `keyword argument`:
|
||||
name1=value1, name2=value2, ...
|
||||
- Use case:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# Case 1
|
||||
with R.start('test'):
|
||||
R.set_tags(release_version="2.2.0")
|
||||
|
||||
# Case 2
|
||||
R.set_tags(release_version="2.2.0")
|
||||
|
||||
.. autoclass:: qlib.workflow.__init__.QlibRecorder
|
||||
:members:
|
||||
|
||||
Experiment Manager
|
||||
===================
|
||||
|
||||
The ``ExpManager`` module in ``Qlib`` is responsible for managing different experiments. Most of the APIs of ``ExpManager`` are similar to ``QlibRecorder``, and the most important API will be the ``get_exp`` method. User can directly refer to the documents above for some detailed information about how to use the ``get_exp`` method.
|
||||
|
||||
.. autoclass:: qlib.workflow.expm.ExpManager
|
||||
:members: get_exp, list_experiments
|
||||
|
||||
For other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.
|
||||
|
||||
Experiment
|
||||
@@ -363,6 +68,9 @@ Experiment
|
||||
|
||||
The ``Experiment`` class is solely responsible for a single experiment, and it will handle any operations that are related to an experiment. Basic methods such as `start`, `end` an experiment are included. Besides, methods related to `recorders` are also available: such methods include `get_recorder` and `list_recorders`.
|
||||
|
||||
.. autoclass:: qlib.workflow.exp.Experiment
|
||||
:members: get_recorder, list_recorders
|
||||
|
||||
For other interfaces such as `search_records`, `delete_recorder`, please refer to `Experiment API <../reference/api.html#experiment>`_.
|
||||
|
||||
Recorder
|
||||
@@ -372,28 +80,8 @@ The ``Recorder`` class is responsible for a single recorder. It will handle some
|
||||
|
||||
Here are some important APIs that are not included in the ``QlibRecorder``:
|
||||
|
||||
- `list_artifacts(artifact_path: str = None)`
|
||||
- List all the artifacts of a recorder.
|
||||
- Parameters:
|
||||
- `artifact_path` : str
|
||||
the relative path for the artifact to be stored in the URI.
|
||||
- Returns:
|
||||
- A list of artifacts information (name, path, etc.) that being stored.
|
||||
|
||||
- `list_metrics()`
|
||||
- List all the metrics of a recorder.
|
||||
- Returns:
|
||||
- A dictionary of metrics that being stored.
|
||||
|
||||
- `list_params()`
|
||||
- List all the params of a recorder.
|
||||
- Returns:
|
||||
- A dictionary of params that being stored.
|
||||
|
||||
- `list_tags()`
|
||||
- List all the tags of a recorder.
|
||||
- Returns:
|
||||
- A dictionary of tags that being stored.
|
||||
.. autoclass:: qlib.workflow.recorder.Recorder
|
||||
:members: list_artifacts, list_metrics, list_params, list_tags
|
||||
|
||||
For other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.
|
||||
|
||||
@@ -402,8 +90,8 @@ Record Template
|
||||
|
||||
The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
|
||||
|
||||
- ``SignalRecord``: This class generates the `preidction` of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR`.
|
||||
- ``SignalRecord``: This class generates the `preidction` results of the model.
|
||||
- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
|
||||
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
|
||||
|
||||
For more information, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
|
||||
@@ -1,13 +1,13 @@
|
||||
.. _report:
|
||||
|
||||
==========================================
|
||||
Aanalysis: Evaluation & Results Analysis
|
||||
Analysis: Evaluation & Results Analysis
|
||||
==========================================
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Aanalysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
|
||||
``Analysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
|
||||
|
||||
- analysis_position
|
||||
- report_graph
|
||||
|
||||
@@ -124,7 +124,7 @@ html_theme_options = {
|
||||
"logo_only": True,
|
||||
"collapse_navigation": False,
|
||||
"display_version": False,
|
||||
"navigation_depth": 3,
|
||||
"navigation_depth": 4,
|
||||
}
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
|
||||
@@ -41,7 +41,7 @@ Document Structure
|
||||
Interday Strategy: Portfolio Management <component/strategy.rst>
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Aanalysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -21,27 +21,27 @@ Framework
|
||||
|
||||
At the module level, Qlib is a platform that consists of above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
|
||||
====================== ==============================================================================
|
||||
Name Description
|
||||
====================== ==============================================================================
|
||||
`Data layer` `DataServer` focuses on providing high-performance infrastructure for users to
|
||||
manage and retrieve raw data. `DataEnhancement` will preprocess the data and
|
||||
provide the best dataset to be fed into the models.
|
||||
|
||||
`Interday Model` `Interday model` focuses on producing prediction scores (aka. `alpha`). Models
|
||||
are trained by `Model Creator` and managed by `Model Manager`. Users could
|
||||
choose one or multiple models for prediction. Multiple models could be combined
|
||||
with `Ensemble` module.
|
||||
|
||||
`Interday Strategy` `Portfolio Generator` will take prediction scores as input and output the
|
||||
orders based on the current position to achieve the target portfolio.
|
||||
======================== ==============================================================================
|
||||
Name Description
|
||||
======================== ==============================================================================
|
||||
`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
|
||||
`DataServer` provides high-performance infrastructure for users to manage
|
||||
and retrieve raw data. `Trainer` provides flexible interface to control
|
||||
the training process of models which enable algorithms controlling the
|
||||
training process.
|
||||
|
||||
`Intraday Trading` `Order Executor` is responsible for executing orders output by
|
||||
`Interday Strategy` and returning the executed results.
|
||||
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
|
||||
`Information Extractor` extracts data for models. `Forecast Model` focuses
|
||||
on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
|
||||
modules. With these signals `Portfolio Generator` will generate the target
|
||||
portfolio and produce orders to be executed by `Order Executor`.
|
||||
|
||||
`Analysis` Users could get a detailed analysis report of forecasting signals and portfolios
|
||||
in this part.
|
||||
====================== ==============================================================================
|
||||
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
|
||||
system. `Analyser` module will provide users detailed analysis reports of
|
||||
forecasting signals, portfolios and execution results
|
||||
======================== ==============================================================================
|
||||
|
||||
- The modules with hand-drawn style are under development and will be released in the future.
|
||||
- The modules with dashed borders are highly user-customizable and extendible.
|
||||
|
||||
@@ -84,7 +84,7 @@ Auto Quant Research Workflow
|
||||
- Run ``examples/workflow_by_code.ipynb`` with jupyter notebook
|
||||
Users can have portfolio analysis or prediction score (model prediction) analysis by run ``examples/workflow_by_code.ipynb``.
|
||||
- Graphical Reports
|
||||
Users can get graphical reports about the analysis, please refer to `Aanalysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
|
||||
Users can get graphical reports about the analysis, please refer to `Analysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -23,16 +23,13 @@ Filter
|
||||
.. automodule:: qlib.data.filter
|
||||
:members:
|
||||
|
||||
Feature
|
||||
--------------------
|
||||
|
||||
Class
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
--------------------
|
||||
.. automodule:: qlib.data.base
|
||||
:members:
|
||||
|
||||
Operator
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
--------------------
|
||||
.. automodule:: qlib.data.ops
|
||||
:members:
|
||||
|
||||
@@ -56,16 +53,33 @@ Cache
|
||||
.. autoclass:: qlib.data.cache.DiskDatasetCache
|
||||
:members:
|
||||
|
||||
Dataset
|
||||
---------------
|
||||
|
||||
Dataset Class
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.__init__
|
||||
:members:
|
||||
|
||||
Data Loader
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.loader
|
||||
:members:
|
||||
|
||||
Data Handler
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.handler
|
||||
:members:
|
||||
|
||||
Processor
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: qlib.data.dataset.processor
|
||||
:members:
|
||||
|
||||
|
||||
Contrib
|
||||
====================
|
||||
|
||||
|
||||
Data Handler
|
||||
---------------
|
||||
.. automodule:: qlib.data.dataset.handler
|
||||
:members:
|
||||
|
||||
Model
|
||||
--------------------
|
||||
.. automodule:: qlib.model.base
|
||||
|
||||
@@ -12,14 +12,16 @@ Initialization
|
||||
|
||||
Please follow the steps below to initialize ``Qlib``.
|
||||
|
||||
- Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>` for more information about customized dataset.
|
||||
Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance <https://finance.yahoo.com/lookup>`_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized dataset.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
|
||||
|
||||
Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
|
||||
|
||||
|
||||
- Initialize Qlib before calling other APIs: run following code in python.
|
||||
Initialize Qlib before calling other APIs: run following code in python.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ Custom Model Integration
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``.
|
||||
``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``DNN``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. In addition to the default models ``Qlib`` provide, users can integrate their own custom models into ``Qlib``.
|
||||
|
||||
Users can integrate their own custom models according to the following steps.
|
||||
|
||||
@@ -32,79 +32,76 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
|
||||
- Override the `fit` method
|
||||
- ``Qlib`` calls the fit method to train the model
|
||||
- The parameters must include training feature `x_train`, training label `y_train`, test feature `x_valid`, test label `y_valid` at least.
|
||||
- The parameters could include some optional parameters with default values, such as train weight `w_train`, test weight `w_valid` and `num_boost_round = 1000`.
|
||||
- The parameters must include training feature `dataset`.
|
||||
- The parameters could include some optional parameters with default values, such as `num_boost_round = 1000` for `GBDT`.
|
||||
- Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.
|
||||
.. code-block:: Python
|
||||
|
||||
def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame,
|
||||
w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs):
|
||||
def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
|
||||
|
||||
# prepare dataset for lgb training and evaluation
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError('LightGBM doesn\'t support multi-label training')
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
w_train_weight = None if w_train is None else w_train.values
|
||||
w_valid_weight = None if w_valid is None else w_valid.values
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
|
||||
self._model = lgb.train(
|
||||
self._params,
|
||||
dtrain,
|
||||
# fit the model
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=['train', 'valid'],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
- Override the `predict` method
|
||||
- The parameters include the test features.
|
||||
- The parameters must include training feature `dataset`, which will be userd to get the test dataset.
|
||||
- Return the `prediction score`.
|
||||
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
|
||||
- Code Example: In the following example, users need to use dnn to predict the label(such as `preds`) of test data `x_test` and return it.
|
||||
- Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
|
||||
.. code-block:: Python
|
||||
|
||||
def predict(self, x_test:pd.DataFrame, **kwargs)-> numpy.ndarray:
|
||||
if self._model is None:
|
||||
raise ValueError('model is not fitted yet!')
|
||||
return self._model.predict(x_test.values)
|
||||
def predict(self, dataset: DatasetH, **kwargs)-> pandas.Series:
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `save` method & `load` method
|
||||
- The `save` method parameter includes the a `filename` that represents an absolute path, user need to save model into the path.
|
||||
- The `load` method parameter includes the a `buffer` read from the `filename` passed in the `save` method, users need to load model from the `buffer`.
|
||||
- Code Example:
|
||||
- Override the `finetune` method
|
||||
- The parameters must include training feature `dataset`.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
def save(self, filename):
|
||||
if self._model is None:
|
||||
raise ValueError('model is not fitted yet!')
|
||||
self._model.save_model(filename)
|
||||
|
||||
def load(self, buffer):
|
||||
self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')})
|
||||
|
||||
.. Without tuner, this part will not be used
|
||||
.. - Override the `score` method(This step is optional)
|
||||
.. - The parameters include the test features and test labels.
|
||||
.. - Return the evaluation score of the model. It's recommended to adopt the loss between labels and `prediction score`.
|
||||
.. - Code Example: In the following example, users need to calculate the weighted loss with test data `x_test`, test label `y_test` and the weight `w_test`.
|
||||
.. .. code-block:: Python
|
||||
..
|
||||
.. def score(self, x_test:pd.Dataframe, y_test:pd.Dataframe, w_test:pd.DataFrame = None) -> float:
|
||||
.. # Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
.. x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
|
||||
.. preds = self.predict(x_test)
|
||||
.. w_test_weight = None if w_test is None else w_test.values
|
||||
.. scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score
|
||||
.. return scorer(y_test.values, preds, sample_weight=w_test_weight)
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
)
|
||||
|
||||
Configuration File
|
||||
=======================
|
||||
|
||||
The configuration file is described in detail in the `estimator <../component/estimator.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file.
|
||||
The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file.
|
||||
|
||||
- Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
|
||||
|
||||
@@ -124,20 +121,20 @@ The configuration file is described in detail in the `estimator <../component/es
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
|
||||
Users could find configuration file of the baseline of the ``Model`` in ``qlib/examples/estimator/estimator_config.yaml`` and ``qlib/examples/estimator/estimator_config_dnn.yaml``
|
||||
Users could find configuration file of the baselines of the ``Model`` in ``examples/benchmarks``. All the configurations of different models are listed under the corresponding model folder.
|
||||
|
||||
Model Testing
|
||||
=====================
|
||||
Assuming that the configuration file is ``examples/estimator/estimator_config.yaml``, users can run the following command to test the custom model:
|
||||
Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following command to test the custom model:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd examples # Avoid running program under the directory contains `qlib`
|
||||
estimator -c estimator/estimator_config.yaml
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
|
||||
|
||||
.. note:: ``estimator`` is a built-in command of ``Qlib``.
|
||||
.. note:: ``qrun`` is a built-in command of ``Qlib``.
|
||||
|
||||
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/workflow_by_code.ipynb``.
|
||||
|
||||
|
||||
Reference
|
||||
|
||||
10
examples/benchmarks/ALSTM/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# ALSTM
|
||||
|
||||
- ALSTM contains a temporal attentive aggregation layer based on normal LSTM.
|
||||
|
||||
- The code used in Qlib is a pyTorch implementation of Code: https://github.com/fulifeng/Adv-ALSTM
|
||||
|
||||
- Paper: A dual-stage attention-based recurrent neural network for time series prediction.
|
||||
|
||||
https://www.ijcai.org/Proceedings/2017/0366.pdf
|
||||
|
||||
4
examples/benchmarks/ALSTM/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
83
examples/benchmarks/ALSTM/workflow_config_alstm.yaml
Normal file
@@ -0,0 +1,83 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ALSTM
|
||||
module_path: qlib.contrib.model.pytorch_alstm
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
rnn_type: GRU
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -30,8 +30,13 @@ task:
|
||||
module_path: qlib.contrib.model.catboost_model
|
||||
kwargs:
|
||||
loss: RMSE
|
||||
iterations: 5
|
||||
learning_rate: 0.03
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 6
|
||||
num_leaves: 100
|
||||
thread_count: 20
|
||||
grow_policy: Lossguide
|
||||
bootstrap_type: Poisson
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
@@ -56,4 +61,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
5
examples/benchmarks/GATs/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# GATs
|
||||
* Graph Attention Networks(GATs) leverage masked self-attentional layers on graph-structured data. The nodes in stacked layers have different weights and they are able to attend over their
|
||||
neighborhoods’ features, without requiring any kind of costly matrix operation (such as inversion) or depending on knowing the graph structure upfront.
|
||||
* This code used in Qlib is implemented with PyTorch by ourselves.
|
||||
* Paper: Graph Attention Networks https://arxiv.org/pdf/1710.10903.pdf
|
||||
@@ -36,7 +36,6 @@ task:
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: LSTM
|
||||
|
||||
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
@@ -37,7 +51,7 @@ task:
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: IC
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
@@ -46,7 +60,7 @@ task:
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360_Denoise
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
|
||||
15
examples/benchmarks/HATS/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
## Requirement
|
||||
|
||||
* pandas==1.1.2
|
||||
* numpy==1.17.4
|
||||
* scikit_learn==0.23.2
|
||||
* torch==1.7.0
|
||||
|
||||
## HATS
|
||||
|
||||
* HATS is a a hierarchical attention network for stock prediction which uses relational data for stock market prediction. HATS selectively aggregates information
|
||||
on different relation types and adds the information to the representations of each company. HATS is used as a relational modeling module with initialized node representations.Furthermore, HATS
|
||||
can predict not only individual stock prices but also market index movements, which is similar to the graph classification task.
|
||||
|
||||
* HATS uses pretrained model of GRU and LSTM. The code of GRU and LSTM used in Qlib is a pyTorch implemention of GRU and LSTM.
|
||||
* Paper address:HATS: A Hierarchical Graph Attention Network for Stock Movement Prediction https://arxiv.org/pdf/1908.07999.pdf
|
||||
4
examples/benchmarks/HATS/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
77
examples/benchmarks/HATS/worflow_config_hats.yaml
Normal file
@@ -0,0 +1,77 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: HATS
|
||||
module_path: qlib.contrib.model.pytorch_hats
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.6
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
metric: loss
|
||||
loss: mse
|
||||
base_model: GRU
|
||||
seed: 0
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
@@ -37,7 +51,7 @@ task:
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: IC
|
||||
metric: loss
|
||||
loss: mse
|
||||
seed: 0
|
||||
GPU: 0
|
||||
@@ -46,7 +60,7 @@ task:
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360_Denoise
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
|
||||
4
examples/benchmarks/SFM/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# State-Frequency-Memory
|
||||
- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform (DFT) to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
|
||||
- The code used in Qlib is a pyTorch implementation of SFM (Zhang, L., Aggarwal, C., & Qi, G. J. (2017,)).
|
||||
- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.
|
||||
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
@@ -31,27 +45,25 @@ task:
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
output_dim: 1
|
||||
freq_dim: 15
|
||||
output_dim: 32
|
||||
freq_dim: 25
|
||||
dropout_W: 0.5
|
||||
dropout_U: 0.5
|
||||
n_epochs: 10
|
||||
n_epochs: 20
|
||||
lr: 1e-3
|
||||
batch_size: 800
|
||||
batch_size: 1600
|
||||
early_stop: 20
|
||||
eval_steps: 5
|
||||
loss: mse
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
optimizer: gd
|
||||
optimizer: adam
|
||||
GPU: 1
|
||||
seed: 0
|
||||
seed: 710
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360_Denoise
|
||||
class: ALPHA360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
@@ -70,4 +82,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
14
examples/benchmarks/TFT/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Temporal Fusion Transformers Benchmark
|
||||
## Source
|
||||
**Reference**: Lim, Bryan, et al. "Temporal fusion transformers for interpretable multi-horizon time series forecasting." arXiv preprint arXiv:1912.09363 (2019).
|
||||
|
||||
**GitHub**: https://github.com/google-research/google-research/tree/master/tft
|
||||
|
||||
## Run the Workflow
|
||||
Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
|
||||
|
||||
### Notes
|
||||
1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
|
||||
2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
|
||||
3. The model must run in GPU, or an error will be raised.
|
||||
4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
|
||||
14
examples/benchmarks/TFT/data_formatters/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
223
examples/benchmarks/TFT/data_formatters/base.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Default data formatting functions for experiments.
|
||||
|
||||
For new datasets, inherit form GenericDataFormatter and implement
|
||||
all abstract functions.
|
||||
|
||||
These dataset-specific methods:
|
||||
1) Define the column and input types for tabular dataframes used by model
|
||||
2) Perform the necessary input feature engineering & normalisation steps
|
||||
3) Reverts the normalisation for predictions
|
||||
4) Are responsible for train, validation and test splits
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import abc
|
||||
import enum
|
||||
|
||||
|
||||
# Type defintions
|
||||
class DataTypes(enum.IntEnum):
|
||||
"""Defines numerical types of each column."""
|
||||
|
||||
REAL_VALUED = 0
|
||||
CATEGORICAL = 1
|
||||
DATE = 2
|
||||
|
||||
|
||||
class InputTypes(enum.IntEnum):
|
||||
"""Defines input types of each column."""
|
||||
|
||||
TARGET = 0
|
||||
OBSERVED_INPUT = 1
|
||||
KNOWN_INPUT = 2
|
||||
STATIC_INPUT = 3
|
||||
ID = 4 # Single column used as an entity identifier
|
||||
TIME = 5 # Single column exclusively used as a time index
|
||||
|
||||
|
||||
class GenericDataFormatter(abc.ABC):
|
||||
"""Abstract base class for all data formatters.
|
||||
|
||||
User can implement the abstract methods below to perform dataset-specific
|
||||
manipulations.
|
||||
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_scalers(self, df):
|
||||
"""Calibrates scalers using the data supplied."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def transform_inputs(self, df):
|
||||
"""Performs feature transformation."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def format_predictions(self, df):
|
||||
"""Reverts any normalisation to give predictions in original scale."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def split_data(self, df):
|
||||
"""Performs the default train, validation and test splits."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _column_definition(self):
|
||||
"""Defines order, input type and data type of each column."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_fixed_params(self):
|
||||
"""Defines the fixed parameters used by the model for training.
|
||||
|
||||
Requires the following keys:
|
||||
'total_time_steps': Defines the total number of time steps used by TFT
|
||||
'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
|
||||
'num_epochs': Maximum number of epochs for training
|
||||
'early_stopping_patience': Early stopping param for keras
|
||||
'multiprocessing_workers': # of cpus for data processing
|
||||
|
||||
|
||||
Returns:
|
||||
A dictionary of fixed parameters, e.g.:
|
||||
|
||||
fixed_params = {
|
||||
'total_time_steps': 252 + 5,
|
||||
'num_encoder_steps': 252,
|
||||
'num_epochs': 100,
|
||||
'early_stopping_patience': 5,
|
||||
'multiprocessing_workers': 5,
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# Shared functions across data-formatters
|
||||
@property
|
||||
def num_classes_per_cat_input(self):
|
||||
"""Returns number of categories per relevant input.
|
||||
|
||||
This is seqeuently required for keras embedding layers.
|
||||
"""
|
||||
return self._num_classes_per_cat_input
|
||||
|
||||
def get_num_samples_for_calibration(self):
|
||||
"""Gets the default number of training and validation samples.
|
||||
|
||||
Use to sub-sample the data for network calibration and a value of -1 uses
|
||||
all available samples.
|
||||
|
||||
Returns:
|
||||
Tuple of (training samples, validation samples)
|
||||
"""
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
# Sanity checks first.
|
||||
# Ensure only one ID and time column exist
|
||||
def _check_single_column(input_type):
|
||||
|
||||
length = len([tup for tup in column_definition if tup[2] == input_type])
|
||||
|
||||
if length != 1:
|
||||
raise ValueError("Illegal number of inputs ({}) of type {}".format(length, input_type))
|
||||
|
||||
_check_single_column(InputTypes.ID)
|
||||
_check_single_column(InputTypes.TIME)
|
||||
|
||||
identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]
|
||||
time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
|
||||
real_inputs = [
|
||||
tup
|
||||
for tup in column_definition
|
||||
if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME}
|
||||
]
|
||||
categorical_inputs = [
|
||||
tup
|
||||
for tup in column_definition
|
||||
if tup[1] == DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME}
|
||||
]
|
||||
|
||||
return identifier + time + real_inputs + categorical_inputs
|
||||
|
||||
def _get_input_columns(self):
|
||||
"""Returns names of all input columns."""
|
||||
return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
|
||||
|
||||
def _get_tft_input_indices(self):
|
||||
"""Returns the relevant indexes and input sizes required by TFT."""
|
||||
|
||||
# Functions
|
||||
def _extract_tuples_from_data_type(data_type, defn):
|
||||
return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}]
|
||||
|
||||
def _get_locations(input_types, defn):
|
||||
return [i for i, tup in enumerate(defn) if tup[2] in input_types]
|
||||
|
||||
# Start extraction
|
||||
column_definition = [
|
||||
tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}
|
||||
]
|
||||
|
||||
categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition)
|
||||
real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition)
|
||||
|
||||
locations = {
|
||||
"input_size": len(self._get_input_columns()),
|
||||
"output_size": len(_get_locations({InputTypes.TARGET}, column_definition)),
|
||||
"category_counts": self.num_classes_per_cat_input,
|
||||
"input_obs_loc": _get_locations({InputTypes.TARGET}, column_definition),
|
||||
"static_input_loc": _get_locations({InputTypes.STATIC_INPUT}, column_definition),
|
||||
"known_regular_inputs": _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs),
|
||||
"known_categorical_inputs": _get_locations(
|
||||
{InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, categorical_inputs
|
||||
),
|
||||
}
|
||||
|
||||
return locations
|
||||
|
||||
def get_experiment_params(self):
|
||||
"""Returns fixed model parameters for experiments."""
|
||||
|
||||
required_keys = [
|
||||
"total_time_steps",
|
||||
"num_encoder_steps",
|
||||
"num_epochs",
|
||||
"early_stopping_patience",
|
||||
"multiprocessing_workers",
|
||||
]
|
||||
|
||||
fixed_params = self.get_fixed_params()
|
||||
|
||||
for k in required_keys:
|
||||
if k not in fixed_params:
|
||||
raise ValueError("Field {}".format(k) + " missing from fixed parameter definitions!")
|
||||
|
||||
fixed_params["column_definition"] = self.get_column_definition()
|
||||
|
||||
fixed_params.update(self._get_tft_input_indices())
|
||||
|
||||
return fixed_params
|
||||
219
examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Custom formatting functions for Alpha158 dataset.
|
||||
|
||||
Defines dataset specific column definitions and data transformations.
|
||||
"""
|
||||
|
||||
import data_formatters.base
|
||||
import libs.utils as utils
|
||||
import sklearn.preprocessing
|
||||
|
||||
GenericDataFormatter = data_formatters.base.GenericDataFormatter
|
||||
DataTypes = data_formatters.base.DataTypes
|
||||
InputTypes = data_formatters.base.InputTypes
|
||||
|
||||
|
||||
class Alpha158Formatter(GenericDataFormatter):
|
||||
"""Defines and formats data for the Alpha158 dataset.
|
||||
|
||||
Attributes:
|
||||
column_definition: Defines input and data type of column used in the
|
||||
experiment.
|
||||
identifiers: Entity identifiers used in experiments.
|
||||
"""
|
||||
|
||||
_column_definition = [
|
||||
("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
|
||||
("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
|
||||
("date", DataTypes.DATE, InputTypes.TIME),
|
||||
("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
|
||||
# Selected 10 features
|
||||
("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
|
||||
("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialises formatter."""
|
||||
|
||||
self.identifiers = None
|
||||
self._real_scalers = None
|
||||
self._cat_scalers = None
|
||||
self._target_scaler = None
|
||||
self._num_classes_per_cat_input = None
|
||||
|
||||
def split_data(self, df, valid_boundary=2016, test_boundary=2018):
|
||||
"""Splits data frame into training-validation-test data frames.
|
||||
|
||||
This also calibrates scaling object, and transforms data for each split.
|
||||
|
||||
Args:
|
||||
df: Source data frame to split.
|
||||
valid_boundary: Starting year for validation data
|
||||
test_boundary: Starting year for test data
|
||||
|
||||
Returns:
|
||||
Tuple of transformed (train, valid, test) data.
|
||||
"""
|
||||
|
||||
print("Formatting train-valid-test splits.")
|
||||
|
||||
index = df["year"]
|
||||
train = df.loc[index < valid_boundary]
|
||||
valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
|
||||
test = df.loc[index >= test_boundary]
|
||||
|
||||
self.set_scalers(train)
|
||||
|
||||
return (self.transform_inputs(data) for data in [train, valid, test])
|
||||
|
||||
def set_scalers(self, df):
|
||||
"""Calibrates scalers using the data supplied.
|
||||
|
||||
Args:
|
||||
df: Data to use to calibrate scalers.
|
||||
"""
|
||||
print("Setting scalers with training data...")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
|
||||
target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
|
||||
|
||||
# Extract identifiers in case required
|
||||
self.identifiers = list(df[id_column].unique())
|
||||
|
||||
# Format real scalers
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
data = df[real_inputs].values
|
||||
self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
|
||||
self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
|
||||
df[[target_column]].values
|
||||
) # used for predictions
|
||||
|
||||
# Format categorical scalers
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
categorical_scalers = {}
|
||||
num_classes = []
|
||||
for col in categorical_inputs:
|
||||
# Set all to str so that we don't have mixed integer/string columns
|
||||
srs = df[col].apply(str)
|
||||
categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
|
||||
num_classes.append(srs.nunique())
|
||||
|
||||
# Set categorical scaler outputs
|
||||
self._cat_scalers = categorical_scalers
|
||||
self._num_classes_per_cat_input = num_classes
|
||||
|
||||
def transform_inputs(self, df):
|
||||
"""Performs feature transformations.
|
||||
|
||||
This includes both feature engineering, preprocessing and normalisation.
|
||||
|
||||
Args:
|
||||
df: Data frame to transform.
|
||||
|
||||
Returns:
|
||||
Transformed data frame.
|
||||
|
||||
"""
|
||||
output = df.copy()
|
||||
|
||||
if self._real_scalers is None and self._cat_scalers is None:
|
||||
raise ValueError("Scalers have not been set!")
|
||||
|
||||
column_definitions = self.get_column_definition()
|
||||
|
||||
real_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
categorical_inputs = utils.extract_cols_from_data_type(
|
||||
DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
|
||||
)
|
||||
|
||||
# Format real inputs
|
||||
output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
|
||||
|
||||
# Format categorical inputs
|
||||
for col in categorical_inputs:
|
||||
string_df = df[col].apply(str)
|
||||
output[col] = self._cat_scalers[col].transform(string_df)
|
||||
|
||||
return output
|
||||
|
||||
def format_predictions(self, predictions):
|
||||
"""Reverts any normalisation to give predictions in original scale.
|
||||
|
||||
Args:
|
||||
predictions: Dataframe of model predictions.
|
||||
|
||||
Returns:
|
||||
Data frame of unnormalised predictions.
|
||||
"""
|
||||
output = predictions.copy()
|
||||
|
||||
column_names = predictions.columns
|
||||
|
||||
for col in column_names:
|
||||
if col not in {"forecast_time", "identifier"}:
|
||||
output[col] = self._target_scaler.inverse_transform(predictions[col])
|
||||
|
||||
return output
|
||||
|
||||
# Default params
|
||||
def get_fixed_params(self):
|
||||
"""Returns fixed model parameters for experiments."""
|
||||
|
||||
fixed_params = {
|
||||
"total_time_steps": 16 + 6,
|
||||
"num_encoder_steps": 16,
|
||||
"num_epochs": 100,
|
||||
"early_stopping_patience": 5,
|
||||
"multiprocessing_workers": 5,
|
||||
}
|
||||
|
||||
return fixed_params
|
||||
|
||||
def get_default_model_params(self):
|
||||
"""Returns default optimised model parameters."""
|
||||
|
||||
model_params = {
|
||||
"dropout_rate": 0.3,
|
||||
"hidden_layer_size": 160,
|
||||
"learning_rate": 0.001,
|
||||
"minibatch_size": 64,
|
||||
"max_gradient_norm": 0.01,
|
||||
"num_heads": 1,
|
||||
"stack_size": 1,
|
||||
}
|
||||
|
||||
return model_params
|
||||
14
examples/benchmarks/TFT/expt_settings/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
95
examples/benchmarks/TFT/expt_settings/configs.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Default configs for TFT experiments.
|
||||
|
||||
Contains the default output paths for data, serialised models and predictions
|
||||
for the main experiments used in the publication.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import data_formatters.qlib_Alpha158
|
||||
|
||||
|
||||
class ExperimentConfig(object):
|
||||
"""Defines experiment configs and paths to outputs.
|
||||
|
||||
Attributes:
|
||||
root_folder: Root folder to contain all experimental outputs.
|
||||
experiment: Name of experiment to run.
|
||||
data_folder: Folder to store data for experiment.
|
||||
model_folder: Folder to store serialised models.
|
||||
results_folder: Folder to store results.
|
||||
data_csv_path: Path to primary data csv file used in experiment.
|
||||
hyperparam_iterations: Default number of random search iterations for
|
||||
experiment.
|
||||
"""
|
||||
|
||||
default_experiments = ["Alpha158"]
|
||||
|
||||
def __init__(self, experiment="volatility", root_folder=None):
|
||||
"""Creates configs based on default experiment chosen.
|
||||
|
||||
Args:
|
||||
experiment: Name of experiment.
|
||||
root_folder: Root folder to save all outputs of training.
|
||||
"""
|
||||
|
||||
if experiment not in self.default_experiments:
|
||||
raise ValueError("Unrecognised experiment={}".format(experiment))
|
||||
|
||||
# Defines all relevant paths
|
||||
if root_folder is None:
|
||||
root_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "outputs")
|
||||
print("Using root folder {}".format(root_folder))
|
||||
|
||||
self.root_folder = root_folder
|
||||
self.experiment = experiment
|
||||
self.data_folder = os.path.join(root_folder, "data", experiment)
|
||||
self.model_folder = os.path.join(root_folder, "saved_models", experiment)
|
||||
self.results_folder = os.path.join(root_folder, "results", experiment)
|
||||
|
||||
# Creates folders if they don't exist
|
||||
for relevant_directory in [self.root_folder, self.data_folder, self.model_folder, self.results_folder]:
|
||||
if not os.path.exists(relevant_directory):
|
||||
os.makedirs(relevant_directory)
|
||||
|
||||
@property
|
||||
def data_csv_path(self):
|
||||
csv_map = {
|
||||
"Alpha158": "Alpha158.csv",
|
||||
}
|
||||
|
||||
return os.path.join(self.data_folder, csv_map[self.experiment])
|
||||
|
||||
@property
|
||||
def hyperparam_iterations(self):
|
||||
|
||||
return 240 if self.experiment == "volatility" else 60
|
||||
|
||||
def make_data_formatter(self):
|
||||
"""Gets a data formatter object for experiment.
|
||||
|
||||
Returns:
|
||||
Default DataFormatter per experiment.
|
||||
"""
|
||||
|
||||
data_formatter_class = {
|
||||
"Alpha158": data_formatters.qlib_Alpha158.Alpha158Formatter,
|
||||
}
|
||||
|
||||
return data_formatter_class[self.experiment]()
|
||||
14
examples/benchmarks/TFT/libs/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
430
examples/benchmarks/TFT/libs/hyperparam_opt.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Classes used for hyperparameter optimisation.
|
||||
|
||||
Two main classes exist:
|
||||
1) HyperparamOptManager used for optimisation on a single machine/GPU.
|
||||
2) DistributedHyperparamOptManager for multiple GPUs on different machines.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import shutil
|
||||
import libs.utils as utils
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
Deque = collections.deque
|
||||
|
||||
|
||||
class HyperparamOptManager:
|
||||
"""Manages hyperparameter optimisation using random search for a single GPU.
|
||||
|
||||
Attributes:
|
||||
param_ranges: Discrete hyperparameter range for random search.
|
||||
results: Dataframe of validation results.
|
||||
fixed_params: Fixed model parameters per experiment.
|
||||
saved_params: Dataframe of parameters trained.
|
||||
best_score: Minimum validation loss observed thus far.
|
||||
optimal_name: Key to best configuration.
|
||||
hyperparam_folder: Where to save optimisation outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True):
|
||||
"""Instantiates model.
|
||||
|
||||
Args:
|
||||
param_ranges: Discrete hyperparameter range for random search.
|
||||
fixed_params: Fixed model parameters per experiment.
|
||||
model_folder: Folder to store optimisation artifacts.
|
||||
override_w_fixed_params: Whether to override serialsed fixed model
|
||||
parameters with new supplied values.
|
||||
"""
|
||||
|
||||
self.param_ranges = param_ranges
|
||||
|
||||
self._max_tries = 1000
|
||||
self.results = pd.DataFrame()
|
||||
self.fixed_params = fixed_params
|
||||
self.saved_params = pd.DataFrame()
|
||||
|
||||
self.best_score = np.Inf
|
||||
self.optimal_name = ""
|
||||
|
||||
# Setup
|
||||
# Create folder for saving if its not there
|
||||
self.hyperparam_folder = model_folder
|
||||
utils.create_folder_if_not_exist(self.hyperparam_folder)
|
||||
|
||||
self._override_w_fixed_params = override_w_fixed_params
|
||||
|
||||
def load_results(self):
|
||||
"""Loads results from previous hyperparameter optimisation.
|
||||
|
||||
Returns:
|
||||
A boolean indicating if previous results can be loaded.
|
||||
"""
|
||||
print("Loading results from", self.hyperparam_folder)
|
||||
|
||||
results_file = os.path.join(self.hyperparam_folder, "results.csv")
|
||||
params_file = os.path.join(self.hyperparam_folder, "params.csv")
|
||||
|
||||
if os.path.exists(results_file) and os.path.exists(params_file):
|
||||
|
||||
self.results = pd.read_csv(results_file, index_col=0)
|
||||
self.saved_params = pd.read_csv(params_file, index_col=0)
|
||||
|
||||
if not self.results.empty:
|
||||
self.results.at["loss"] = self.results.loc["loss"].apply(float)
|
||||
self.best_score = self.results.loc["loss"].min()
|
||||
|
||||
is_optimal = self.results.loc["loss"] == self.best_score
|
||||
self.optimal_name = self.results.T[is_optimal].index[0]
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _get_params_from_name(self, name):
|
||||
"""Returns previously saved parameters given a key."""
|
||||
params = self.saved_params
|
||||
|
||||
selected_params = dict(params[name])
|
||||
|
||||
if self._override_w_fixed_params:
|
||||
for k in self.fixed_params:
|
||||
selected_params[k] = self.fixed_params[k]
|
||||
|
||||
return selected_params
|
||||
|
||||
def get_best_params(self):
|
||||
"""Returns the optimal hyperparameters thus far."""
|
||||
|
||||
optimal_name = self.optimal_name
|
||||
|
||||
return self._get_params_from_name(optimal_name)
|
||||
|
||||
def clear(self):
|
||||
"""Clears all previous results and saved parameters."""
|
||||
shutil.rmtree(self.hyperparam_folder)
|
||||
os.makedirs(self.hyperparam_folder)
|
||||
self.results = pd.DataFrame()
|
||||
self.saved_params = pd.DataFrame()
|
||||
|
||||
def _check_params(self, params):
|
||||
"""Checks that parameter map is properly defined."""
|
||||
|
||||
valid_fields = list(self.param_ranges.keys()) + list(self.fixed_params.keys())
|
||||
invalid_fields = [k for k in params if k not in valid_fields]
|
||||
missing_fields = [k for k in valid_fields if k not in params]
|
||||
|
||||
if invalid_fields:
|
||||
raise ValueError("Invalid Fields Found {} - Valid ones are {}".format(invalid_fields, valid_fields))
|
||||
if missing_fields:
|
||||
raise ValueError("Missing Fields Found {} - Valid ones are {}".format(missing_fields, valid_fields))
|
||||
|
||||
def _get_name(self, params):
|
||||
"""Returns a unique key for the supplied set of params."""
|
||||
|
||||
self._check_params(params)
|
||||
|
||||
fields = list(params.keys())
|
||||
fields.sort()
|
||||
|
||||
return "_".join([str(params[k]) for k in fields])
|
||||
|
||||
def get_next_parameters(self, ranges_to_skip=None):
|
||||
"""Returns the next set of parameters to optimise.
|
||||
|
||||
Args:
|
||||
ranges_to_skip: Explicitly defines a set of keys to skip.
|
||||
"""
|
||||
if ranges_to_skip is None:
|
||||
ranges_to_skip = set(self.results.index)
|
||||
|
||||
if not isinstance(self.param_ranges, dict):
|
||||
raise ValueError("Only works for random search!")
|
||||
|
||||
param_range_keys = list(self.param_ranges.keys())
|
||||
param_range_keys.sort()
|
||||
|
||||
def _get_next():
|
||||
"""Returns next hyperparameter set per try."""
|
||||
|
||||
parameters = {k: np.random.choice(self.param_ranges[k]) for k in param_range_keys}
|
||||
|
||||
# Adds fixed params
|
||||
for k in self.fixed_params:
|
||||
parameters[k] = self.fixed_params[k]
|
||||
|
||||
return parameters
|
||||
|
||||
for _ in range(self._max_tries):
|
||||
|
||||
parameters = _get_next()
|
||||
name = self._get_name(parameters)
|
||||
|
||||
if name not in ranges_to_skip:
|
||||
return parameters
|
||||
|
||||
raise ValueError("Exceeded max number of hyperparameter searches!!")
|
||||
|
||||
def update_score(self, parameters, loss, model, info=""):
|
||||
"""Updates the results from last optimisation run.
|
||||
|
||||
Args:
|
||||
parameters: Hyperparameters used in optimisation.
|
||||
loss: Validation loss obtained.
|
||||
model: Model to serialised if required.
|
||||
info: Any ancillary information to tag on to results.
|
||||
|
||||
Returns:
|
||||
Boolean flag indicating if the model is the best seen so far.
|
||||
"""
|
||||
|
||||
if np.isnan(loss):
|
||||
loss = np.Inf
|
||||
|
||||
if not os.path.isdir(self.hyperparam_folder):
|
||||
os.makedirs(self.hyperparam_folder)
|
||||
|
||||
name = self._get_name(parameters)
|
||||
|
||||
is_optimal = self.results.empty or loss < self.best_score
|
||||
|
||||
# save the first model
|
||||
if is_optimal:
|
||||
# Try saving first, before updating info
|
||||
if model is not None:
|
||||
print("Optimal model found, updating")
|
||||
model.save(self.hyperparam_folder)
|
||||
self.best_score = loss
|
||||
self.optimal_name = name
|
||||
|
||||
self.results[name] = pd.Series({"loss": loss, "info": info})
|
||||
self.saved_params[name] = pd.Series(parameters)
|
||||
|
||||
self.results.to_csv(os.path.join(self.hyperparam_folder, "results.csv"))
|
||||
self.saved_params.to_csv(os.path.join(self.hyperparam_folder, "params.csv"))
|
||||
|
||||
return is_optimal
|
||||
|
||||
|
||||
class DistributedHyperparamOptManager(HyperparamOptManager):
|
||||
"""Manages distributed hyperparameter optimisation across many gpus."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
param_ranges,
|
||||
fixed_params,
|
||||
root_model_folder,
|
||||
worker_number,
|
||||
search_iterations=1000,
|
||||
num_iterations_per_worker=5,
|
||||
clear_serialised_params=False,
|
||||
):
|
||||
"""Instantiates optimisation manager.
|
||||
|
||||
This hyperparameter optimisation pre-generates #search_iterations
|
||||
hyperparameter combinations and serialises them
|
||||
at the start. At runtime, each worker goes through their own set of
|
||||
parameter ranges. The pregeneration
|
||||
allows for multiple workers to run in parallel on different machines without
|
||||
resulting in parameter overlaps.
|
||||
|
||||
Args:
|
||||
param_ranges: Discrete hyperparameter range for random search.
|
||||
fixed_params: Fixed model parameters per experiment.
|
||||
root_model_folder: Folder to store optimisation artifacts.
|
||||
worker_number: Worker index definining which set of hyperparameters to
|
||||
test.
|
||||
search_iterations: Maximum numer of random search iterations.
|
||||
num_iterations_per_worker: How many iterations are handled per worker.
|
||||
clear_serialised_params: Whether to regenerate hyperparameter
|
||||
combinations.
|
||||
"""
|
||||
|
||||
max_workers = int(np.ceil(search_iterations / num_iterations_per_worker))
|
||||
|
||||
# Sanity checks
|
||||
if worker_number > max_workers:
|
||||
raise ValueError(
|
||||
"Worker number ({}) cannot be larger than the total number of workers!".format(max_workers)
|
||||
)
|
||||
if worker_number > search_iterations:
|
||||
raise ValueError(
|
||||
"Worker number ({}) cannot be larger than the max search iterations ({})!".format(
|
||||
worker_number, search_iterations
|
||||
)
|
||||
)
|
||||
|
||||
print("*** Creating hyperparameter manager for worker {} ***".format(worker_number))
|
||||
|
||||
hyperparam_folder = os.path.join(root_model_folder, str(worker_number))
|
||||
super().__init__(param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True)
|
||||
|
||||
serialised_ranges_folder = os.path.join(root_model_folder, "hyperparams")
|
||||
if clear_serialised_params:
|
||||
print("Regenerating hyperparameter list")
|
||||
if os.path.exists(serialised_ranges_folder):
|
||||
shutil.rmtree(serialised_ranges_folder)
|
||||
|
||||
utils.create_folder_if_not_exist(serialised_ranges_folder)
|
||||
|
||||
self.serialised_ranges_path = os.path.join(serialised_ranges_folder, "ranges_{}.csv".format(search_iterations))
|
||||
self.hyperparam_folder = hyperparam_folder # override
|
||||
self.worker_num = worker_number
|
||||
self.total_search_iterations = search_iterations
|
||||
self.num_iterations_per_worker = num_iterations_per_worker
|
||||
self.global_hyperparam_df = self.load_serialised_hyperparam_df()
|
||||
self.worker_search_queue = self._get_worker_search_queue()
|
||||
|
||||
@property
|
||||
def optimisation_completed(self):
|
||||
return False if self.worker_search_queue else True
|
||||
|
||||
def get_next_parameters(self):
|
||||
"""Returns next dictionary of hyperparameters to optimise."""
|
||||
param_name = self.worker_search_queue.pop()
|
||||
|
||||
params = self.global_hyperparam_df.loc[param_name, :].to_dict()
|
||||
|
||||
# Always override!
|
||||
for k in self.fixed_params:
|
||||
print("Overriding saved {}: {}".format(k, self.fixed_params[k]))
|
||||
|
||||
params[k] = self.fixed_params[k]
|
||||
|
||||
return params
|
||||
|
||||
def load_serialised_hyperparam_df(self):
|
||||
"""Loads serialsed hyperparameter ranges from file.
|
||||
|
||||
Returns:
|
||||
DataFrame containing hyperparameter combinations.
|
||||
"""
|
||||
print(
|
||||
"Loading params for {} search iterations form {}".format(
|
||||
self.total_search_iterations, self.serialised_ranges_path
|
||||
)
|
||||
)
|
||||
|
||||
if os.path.exists(self.serialised_ranges_folder):
|
||||
df = pd.read_csv(self.serialised_ranges_path, index_col=0)
|
||||
else:
|
||||
print("Unable to load - regenerating serach ranges instead")
|
||||
df = self.update_serialised_hyperparam_df()
|
||||
|
||||
return df
|
||||
|
||||
def update_serialised_hyperparam_df(self):
|
||||
"""Regenerates hyperparameter combinations and saves to file.
|
||||
|
||||
Returns:
|
||||
DataFrame containing hyperparameter combinations.
|
||||
"""
|
||||
search_df = self._generate_full_hyperparam_df()
|
||||
|
||||
print(
|
||||
"Serialising params for {} search iterations to {}".format(
|
||||
self.total_search_iterations, self.serialised_ranges_path
|
||||
)
|
||||
)
|
||||
|
||||
search_df.to_csv(self.serialised_ranges_path)
|
||||
|
||||
return search_df
|
||||
|
||||
def _generate_full_hyperparam_df(self):
|
||||
"""Generates actual hyperparameter combinations.
|
||||
|
||||
Returns:
|
||||
DataFrame containing hyperparameter combinations.
|
||||
"""
|
||||
|
||||
np.random.seed(131) # for reproducibility of hyperparam list
|
||||
|
||||
name_list = []
|
||||
param_list = []
|
||||
for _ in range(self.total_search_iterations):
|
||||
params = super().get_next_parameters(name_list)
|
||||
|
||||
name = self._get_name(params)
|
||||
|
||||
name_list.append(name)
|
||||
param_list.append(params)
|
||||
|
||||
full_search_df = pd.DataFrame(param_list, index=name_list)
|
||||
|
||||
return full_search_df
|
||||
|
||||
def clear(self): # reset when cleared
|
||||
"""Clears results for hyperparameter manager and resets."""
|
||||
super().clear()
|
||||
self.worker_search_queue = self._get_worker_search_queue()
|
||||
|
||||
def load_results(self):
|
||||
"""Load results from file and queue parameter combinations to try.
|
||||
|
||||
Returns:
|
||||
Boolean indicating if results were successfully loaded.
|
||||
"""
|
||||
success = super().load_results()
|
||||
|
||||
if success:
|
||||
self.worker_search_queue = self._get_worker_search_queue()
|
||||
|
||||
return success
|
||||
|
||||
def _get_worker_search_queue(self):
|
||||
"""Generates the queue of param combinations for current worker.
|
||||
|
||||
Returns:
|
||||
Queue of hyperparameter combinations outstanding.
|
||||
"""
|
||||
global_df = self.assign_worker_numbers(self.global_hyperparam_df)
|
||||
worker_df = global_df[global_df["worker"] == self.worker_num]
|
||||
|
||||
left_overs = [s for s in worker_df.index if s not in self.results.columns]
|
||||
|
||||
return Deque(left_overs)
|
||||
|
||||
def assign_worker_numbers(self, df):
|
||||
"""Updates parameter combinations with the index of the worker used.
|
||||
|
||||
Args:
|
||||
df: DataFrame of parameter combinations.
|
||||
|
||||
Returns:
|
||||
Updated DataFrame with worker number.
|
||||
"""
|
||||
output = df.copy()
|
||||
|
||||
n = self.total_search_iterations
|
||||
batch_size = self.num_iterations_per_worker
|
||||
|
||||
max_worker_num = int(np.ceil(n / batch_size))
|
||||
|
||||
worker_idx = np.concatenate([np.tile(i + 1, self.num_iterations_per_worker) for i in range(max_worker_num)])
|
||||
|
||||
output["worker"] = worker_idx[: len(output)]
|
||||
|
||||
return output
|
||||
1280
examples/benchmarks/TFT/libs/tft_model.py
Normal file
224
examples/benchmarks/TFT/libs/utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Generic helper functions used across codebase."""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
|
||||
|
||||
|
||||
# Generic.
|
||||
def get_single_col_by_input_type(input_type, column_definition):
|
||||
"""Returns name of single column.
|
||||
|
||||
Args:
|
||||
input_type: Input type of column to extract
|
||||
column_definition: Column definition list for experiment
|
||||
"""
|
||||
|
||||
l = [tup[0] for tup in column_definition if tup[2] == input_type]
|
||||
|
||||
if len(l) != 1:
|
||||
raise ValueError("Invalid number of columns for {}".format(input_type))
|
||||
|
||||
return l[0]
|
||||
|
||||
|
||||
def extract_cols_from_data_type(data_type, column_definition, excluded_input_types):
|
||||
"""Extracts the names of columns that correspond to a define data_type.
|
||||
|
||||
Args:
|
||||
data_type: DataType of columns to extract.
|
||||
column_definition: Column definition to use.
|
||||
excluded_input_types: Set of input types to exclude
|
||||
|
||||
Returns:
|
||||
List of names for columns with data type specified.
|
||||
"""
|
||||
return [tup[0] for tup in column_definition if tup[1] == data_type and tup[2] not in excluded_input_types]
|
||||
|
||||
|
||||
# Loss functions.
|
||||
def tensorflow_quantile_loss(y, y_pred, quantile):
|
||||
"""Computes quantile loss for tensorflow.
|
||||
|
||||
Standard quantile loss as defined in the "Training Procedure" section of
|
||||
the main TFT paper
|
||||
|
||||
Args:
|
||||
y: Targets
|
||||
y_pred: Predictions
|
||||
quantile: Quantile to use for loss calculations (between 0 & 1)
|
||||
|
||||
Returns:
|
||||
Tensor for quantile loss.
|
||||
"""
|
||||
|
||||
# Checks quantile
|
||||
if quantile < 0 or quantile > 1:
|
||||
raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))
|
||||
|
||||
prediction_underflow = y - y_pred
|
||||
q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * tf.maximum(
|
||||
-prediction_underflow, 0.0
|
||||
)
|
||||
|
||||
return tf.reduce_sum(q_loss, axis=-1)
|
||||
|
||||
|
||||
def numpy_normalised_quantile_loss(y, y_pred, quantile):
|
||||
"""Computes normalised quantile loss for numpy arrays.
|
||||
|
||||
Uses the q-Risk metric as defined in the "Training Procedure" section of the
|
||||
main TFT paper.
|
||||
|
||||
Args:
|
||||
y: Targets
|
||||
y_pred: Predictions
|
||||
quantile: Quantile to use for loss calculations (between 0 & 1)
|
||||
|
||||
Returns:
|
||||
Float for normalised quantile loss.
|
||||
"""
|
||||
prediction_underflow = y - y_pred
|
||||
weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
|
||||
-prediction_underflow, 0.0
|
||||
)
|
||||
|
||||
quantile_loss = weighted_errors.mean()
|
||||
normaliser = y.abs().mean()
|
||||
|
||||
return 2 * quantile_loss / normaliser
|
||||
|
||||
|
||||
# OS related functions.
|
||||
def create_folder_if_not_exist(directory):
|
||||
"""Creates folder if it doesn't exist.
|
||||
|
||||
Args:
|
||||
directory: Folder path to create.
|
||||
"""
|
||||
# Also creates directories recursively
|
||||
pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Tensorflow related functions.
|
||||
def get_default_tensorflow_config(tf_device="gpu", gpu_id=0):
|
||||
"""Creates tensorflow config for graphs to run on CPU or GPU.
|
||||
|
||||
Specifies whether to run graph on gpu or cpu and which GPU ID to use for multi
|
||||
GPU machines.
|
||||
|
||||
Args:
|
||||
tf_device: 'cpu' or 'gpu'
|
||||
gpu_id: GPU ID to use if relevant
|
||||
|
||||
Returns:
|
||||
Tensorflow config.
|
||||
"""
|
||||
|
||||
if tf_device == "cpu":
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # for training on cpu
|
||||
tf_config = tf.ConfigProto(log_device_placement=False, device_count={"GPU": 0})
|
||||
|
||||
else:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
print("Selecting GPU ID={}".format(gpu_id))
|
||||
|
||||
tf_config = tf.ConfigProto(log_device_placement=False)
|
||||
tf_config.gpu_options.allow_growth = True
|
||||
|
||||
return tf_config
|
||||
|
||||
|
||||
def save(tf_session, model_folder, cp_name, scope=None):
|
||||
"""Saves Tensorflow graph to checkpoint.
|
||||
|
||||
Saves all trainiable variables under a given variable scope to checkpoint.
|
||||
|
||||
Args:
|
||||
tf_session: Session containing graph
|
||||
model_folder: Folder to save models
|
||||
cp_name: Name of Tensorflow checkpoint
|
||||
scope: Variable scope containing variables to save
|
||||
"""
|
||||
# Save model
|
||||
if scope is None:
|
||||
saver = tf.train.Saver()
|
||||
else:
|
||||
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
|
||||
saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
|
||||
|
||||
save_path = saver.save(tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name)))
|
||||
print("Model saved to: {0}".format(save_path))
|
||||
|
||||
|
||||
def load(tf_session, model_folder, cp_name, scope=None, verbose=False):
|
||||
"""Loads Tensorflow graph from checkpoint.
|
||||
|
||||
Args:
|
||||
tf_session: Session to load graph into
|
||||
model_folder: Folder containing serialised model
|
||||
cp_name: Name of Tensorflow checkpoint
|
||||
scope: Variable scope to use.
|
||||
verbose: Whether to print additional debugging information.
|
||||
"""
|
||||
# Load model proper
|
||||
load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
|
||||
|
||||
print("Loading model from {0}".format(load_path))
|
||||
|
||||
print_weights_in_checkpoint(model_folder, cp_name)
|
||||
|
||||
initial_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
|
||||
|
||||
# Saver
|
||||
if scope is None:
|
||||
saver = tf.train.Saver()
|
||||
else:
|
||||
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
|
||||
saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
|
||||
# Load
|
||||
saver.restore(tf_session, load_path)
|
||||
all_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
|
||||
|
||||
if verbose:
|
||||
print("Restored {0}".format(",".join(initial_vars.difference(all_vars))))
|
||||
print("Existing {0}".format(",".join(all_vars.difference(initial_vars))))
|
||||
print("All {0}".format(",".join(all_vars)))
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
def print_weights_in_checkpoint(model_folder, cp_name):
|
||||
"""Prints all weights in Tensorflow checkpoint.
|
||||
|
||||
Args:
|
||||
model_folder: Folder containing checkpoint
|
||||
cp_name: Name of checkpoint
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
|
||||
|
||||
print_tensors_in_checkpoint_file(file_name=load_path, tensor_name="", all_tensors=True, all_tensor_names=True)
|
||||
3
examples/benchmarks/TFT/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
tensorflow-gpu==1.15.0
|
||||
numpy == 1.19.4
|
||||
pandas==1.1.0
|
||||
248
examples/benchmarks/TFT/tft.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow.compat.v1 as tf
|
||||
import data_formatters.base
|
||||
import expt_settings.configs
|
||||
import libs.hyperparam_opt
|
||||
import libs.tft_model
|
||||
import libs.utils as utils
|
||||
import os
|
||||
import datetime as dte
|
||||
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
# To register new datasets, please add them here.
|
||||
ALLOW_DATASET = ["Alpha158"]
|
||||
DATASET_SETTING = {
|
||||
"Alpha158": {
|
||||
"feature_col": ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "ROC60", "RESI10"],
|
||||
"label_col": ["LABEL0"],
|
||||
},
|
||||
}
|
||||
# To register new datasets, please add their configurations here.
|
||||
|
||||
|
||||
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
|
||||
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
|
||||
|
||||
|
||||
def fill_test_na(test_df):
|
||||
test_df_res = test_df.copy()
|
||||
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
|
||||
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
|
||||
test_df_res.loc[:, feature_cols] = test_feature_fna
|
||||
return test_df_res
|
||||
|
||||
|
||||
def process_qlib_data(df, dataset, fillna=False):
|
||||
"""Prepare data to fit the TFT model.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
fillna: Whether to fill the data with the mean values.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
# Several features selected manually
|
||||
feature_col = DATASET_SETTING[dataset]["feature_col"]
|
||||
label_col = DATASET_SETTING[dataset]["label_col"]
|
||||
temp_df = df.loc[:, feature_col + label_col]
|
||||
if fillna:
|
||||
temp_df = fill_test_na(temp_df)
|
||||
temp_df = temp_df.swaplevel()
|
||||
temp_df = temp_df.sort_index()
|
||||
temp_df = temp_df.reset_index(level=0)
|
||||
dates = pd.to_datetime(temp_df.index)
|
||||
temp_df["date"] = dates
|
||||
temp_df["day_of_week"] = dates.dayofweek
|
||||
temp_df["month"] = dates.month
|
||||
temp_df["year"] = dates.year
|
||||
temp_df["const"] = 1.0
|
||||
return temp_df
|
||||
|
||||
|
||||
def process_predicted(df, col_name):
|
||||
"""Transform the TFT predicted data into Qlib format.
|
||||
|
||||
Args:
|
||||
df: Original DataFrame.
|
||||
fillna: New column name.
|
||||
|
||||
Returns:
|
||||
Transformed DataFrame.
|
||||
|
||||
"""
|
||||
df_res = df.copy()
|
||||
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+5": col_name})
|
||||
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
|
||||
df_res = df_res[[col_name]]
|
||||
return df_res
|
||||
|
||||
|
||||
def format_score(forecast_df, col_name="pred", label_shift=5):
|
||||
pred = process_predicted(forecast_df, col_name=col_name)
|
||||
pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
|
||||
pred = pred.dropna()[col_name]
|
||||
return pred
|
||||
|
||||
|
||||
def transform_df(df, col_name="LABEL0"):
|
||||
df_res = df["feature"]
|
||||
df_res[col_name] = df["label"]
|
||||
return df_res
|
||||
|
||||
|
||||
class TFTModel(ModelFT):
|
||||
"""TFT Model"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.model = None
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
return transform_df(df_train), transform_df(df_valid)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
DATASET="Alpha158",
|
||||
MODEL_FOLDER="qlib_alpha158_model",
|
||||
LABEL_COL="LABEL0",
|
||||
LABEL_SHIFT=5,
|
||||
USE_GPU_ID=0,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if DATASET not in ALLOW_DATASET:
|
||||
raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
|
||||
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
|
||||
|
||||
train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
|
||||
valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
|
||||
|
||||
ExperimentConfig = expt_settings.configs.ExperimentConfig
|
||||
config = ExperimentConfig(DATASET)
|
||||
self.data_formatter = config.make_data_formatter()
|
||||
self.model_folder = MODEL_FOLDER
|
||||
self.gpu_id = USE_GPU_ID
|
||||
self.label_shift = LABEL_SHIFT
|
||||
self.expt_name = DATASET
|
||||
self.label_col = LABEL_COL
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Training Process===========================
|
||||
ModelClass = libs.tft_model.TemporalFusionTransformer
|
||||
if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
|
||||
raise ValueError(
|
||||
"Data formatters should inherit from"
|
||||
+ "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
|
||||
)
|
||||
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
if use_gpu[0]:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
|
||||
else:
|
||||
self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
|
||||
|
||||
self.data_formatter.set_scalers(train)
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
|
||||
# Wendi: 合并调优的参数和非调优的参数
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
if not os.path.exists(self.model_folder):
|
||||
os.makedirs(self.model_folder)
|
||||
params["model_folder"] = self.model_folder
|
||||
|
||||
print("*** Begin training ***")
|
||||
best_loss = np.Inf
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
self.tf_graph = tf.Graph()
|
||||
with self.tf_graph.as_default():
|
||||
self.sess = tf.Session(config=self.tf_config)
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
self.model = ModelClass(params, use_cudnn=use_gpu[0])
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
self.model.fit(train_df=train, valid_df=valid)
|
||||
print("*** Finished training ***")
|
||||
saved_model_dir = self.model_folder + "/" + "saved_model"
|
||||
if not os.path.exists(saved_model_dir):
|
||||
os.makedirs(saved_model_dir)
|
||||
self.model.save(saved_model_dir)
|
||||
|
||||
def extract_numerical_data(data):
|
||||
"""Strips out forecast time and identifier columns."""
|
||||
return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
|
||||
|
||||
# p50_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p50_forecast),
|
||||
# 0.5)
|
||||
# p90_loss = utils.numpy_normalised_quantile_loss(
|
||||
# extract_numerical_data(targets), extract_numerical_data(p90_forecast),
|
||||
# 0.9)
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
print("Training completed.".format(dte.datetime.now()))
|
||||
# ===========================Training Process===========================
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
d_test = dataset.prepare("test", col_set=["feature", "label"])
|
||||
d_test = transform_df(d_test)
|
||||
d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
|
||||
test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
|
||||
|
||||
use_gpu = (True, self.gpu_id)
|
||||
# ===========================Predicting Process===========================
|
||||
default_keras_session = tf.keras.backend.get_session()
|
||||
|
||||
# Sets up default params
|
||||
fixed_params = self.data_formatter.get_experiment_params()
|
||||
params = self.data_formatter.get_default_model_params()
|
||||
params = {**params, **fixed_params}
|
||||
|
||||
print("*** Begin predicting ***")
|
||||
tf.reset_default_graph()
|
||||
|
||||
with self.tf_graph.as_default():
|
||||
tf.keras.backend.set_session(self.sess)
|
||||
output_map = self.model.predict(test, return_targets=True)
|
||||
targets = self.data_formatter.format_predictions(output_map["targets"])
|
||||
p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
|
||||
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
|
||||
predict = format_score(p90_forecast, "pred", 0) # self.label_shift
|
||||
label = format_score(targets, "label", 0)
|
||||
# ===========================Predicting Process===========================
|
||||
return predict, label
|
||||
|
||||
def finetune(self, dataset: DatasetH):
|
||||
"""
|
||||
finetune model
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
"""
|
||||
pass
|
||||
52
examples/benchmarks/TFT/workflow_config_tft.yaml
Normal file
@@ -0,0 +1,52 @@
|
||||
sys:
|
||||
rel_path: .
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TFTModel
|
||||
module_path: tft
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -44,7 +44,7 @@ task:
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
class: ALPHA360_Denoise
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
|
||||
@@ -29,18 +29,15 @@ task:
|
||||
class: XGBModel
|
||||
module_path: qlib.contrib.model.xgboost
|
||||
kwargs:
|
||||
objective: reg:linear
|
||||
n_estimators: 5000
|
||||
colsample_bytree: 0.85
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
missing: -1
|
||||
min_child_weight: 1
|
||||
eval_metric: rmse
|
||||
colsample_bytree: 0.5
|
||||
eta: 0.2
|
||||
gamma: 0.55
|
||||
max_depth: 2
|
||||
min_child_weight: 1.0
|
||||
n_estimators: 647
|
||||
subsample: 0.8
|
||||
nthread: 4
|
||||
tree_method: hist
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
446
examples/portfolio_optimization_example.ipynb
Normal file
@@ -0,0 +1,446 @@
|
||||
{
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9-final"
|
||||
},
|
||||
"orig_nbformat": 2,
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2,
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import copy\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import qlib\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
" risk_analysis,\n",
|
||||
")\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.utils import flatten_dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"[35366:MainThread](2020-11-27 10:31:09,528) INFO - qlib.Initialization - [__init__.py:41] - default_conf: client.\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:09,531) WARNING - qlib.Initialization - [__init__.py:57] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:09,531) INFO - qlib.Initialization - [__init__.py:76] - qlib successfully initialized based on client settings.\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:09,532) INFO - qlib.Initialization - [__init__.py:79] - data_path=/home/dongzho/.qlib/qlib_data/cn_data\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# use default data\n",
|
||||
"# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
|
||||
"provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
|
||||
"if not exists_qlib_data(provider_uri):\n",
|
||||
" print(f\"Qlib data is not found in {provider_uri}\")\n",
|
||||
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
|
||||
" from get_data import GetData\n",
|
||||
" GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
|
||||
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"market = \"csi300\"\n",
|
||||
"benchmark = \"SH000300\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"## Model Training"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"[35366:MainThread](2020-11-27 10:31:29,731) INFO - qlib.timer - [log.py:81] - Time cost: 20.103s | Loading data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:30,557) INFO - qlib.timer - [log.py:81] - Time cost: 0.241s | DropnaLabel Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,518) INFO - qlib.timer - [log.py:81] - Time cost: 7.960s | CSZScoreNorm Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,519) INFO - qlib.timer - [log.py:81] - Time cost: 8.786s | fit & process data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,520) INFO - qlib.timer - [log.py:81] - Time cost: 28.891s | Init data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,527) INFO - qlib.workflow - [exp.py:180] - Experiment 2 starts running ...\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,651) INFO - qlib.workflow - [recorder.py:234] - Recorder c81375e3b5474feb9c77711babd158c3 starts running under Experiment 2 ...\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:38,652) INFO - qlib.workflow - [expm.py:251] - No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.\n",
|
||||
"Training until validation scores don't improve for 50 rounds\n",
|
||||
"[20]\ttrain's l2: 0.990559\tvalid's l2: 0.994332\n",
|
||||
"[40]\ttrain's l2: 0.98687\tvalid's l2: 0.993702\n",
|
||||
"[60]\ttrain's l2: 0.984308\tvalid's l2: 0.993503\n",
|
||||
"[80]\ttrain's l2: 0.982202\tvalid's l2: 0.993446\n",
|
||||
"[100]\ttrain's l2: 0.980318\tvalid's l2: 0.993423\n",
|
||||
"[120]\ttrain's l2: 0.97854\tvalid's l2: 0.993409\n",
|
||||
"[140]\ttrain's l2: 0.97679\tvalid's l2: 0.993413\n",
|
||||
"[160]\ttrain's l2: 0.975116\tvalid's l2: 0.993473\n",
|
||||
"Early stopping, best iteration is:\n",
|
||||
"[127]\ttrain's l2: 0.977957\tvalid's l2: 0.993381\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# train model\n",
|
||||
"###################################\n",
|
||||
"data_handler_config = {\n",
|
||||
" \"start_time\": \"2008-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"fit_start_time\": \"2008-01-01\",\n",
|
||||
" \"fit_end_time\": \"2014-12-31\",\n",
|
||||
" \"instruments\": market,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"task = {\n",
|
||||
" \"model\": {\n",
|
||||
" \"class\": \"LGBModel\",\n",
|
||||
" \"module_path\": \"qlib.contrib.model.gbdt\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"loss\": \"mse\",\n",
|
||||
" \"colsample_bytree\": 0.8879,\n",
|
||||
" \"learning_rate\": 0.0421,\n",
|
||||
" \"subsample\": 0.8789,\n",
|
||||
" \"lambda_l1\": 205.6999,\n",
|
||||
" \"lambda_l2\": 580.9768,\n",
|
||||
" \"max_depth\": 8,\n",
|
||||
" \"num_leaves\": 210,\n",
|
||||
" \"num_threads\": 20,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"dataset\": {\n",
|
||||
" \"class\": \"DatasetH\",\n",
|
||||
" \"module_path\": \"qlib.data.dataset\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"handler\": {\n",
|
||||
" \"class\": \"Alpha158\",\n",
|
||||
" \"module_path\": \"qlib.contrib.data.handler\",\n",
|
||||
" \"kwargs\": data_handler_config,\n",
|
||||
" },\n",
|
||||
" \"segments\": {\n",
|
||||
" \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
|
||||
" \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
|
||||
" \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# model initiaiton\n",
|
||||
"model = init_instance_by_config(task[\"model\"])\n",
|
||||
"dataset = init_instance_by_config(task[\"dataset\"])\n",
|
||||
"\n",
|
||||
"# start exp to train model\n",
|
||||
"with R.start(experiment_name=\"train_model\"):\n",
|
||||
" R.log_params(**flatten_dict(task))\n",
|
||||
" model.fit(dataset)\n",
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
" rid = R.get_recorder().id\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"source": [
|
||||
"## Optimization Based Strategy"
|
||||
],
|
||||
"cell_type": "markdown",
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.strategy.strategy import BaseStrategy\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class OptBasedStrategy(BaseStrategy):\n",
|
||||
" \"\"\"Optimization Based Strategy\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, data_handler, cov_estimator, optimizer):\n",
|
||||
" self.data_handler = data_handler\n",
|
||||
" self.cov_estimator = cov_estimator\n",
|
||||
" self.optimizer = optimizer\n",
|
||||
"\n",
|
||||
" def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):\n",
|
||||
" \"\"\"\n",
|
||||
" Parameters\n",
|
||||
" -----------\n",
|
||||
" score_series : pd.Seires\n",
|
||||
" stock_id , score.\n",
|
||||
" current : Position()\n",
|
||||
" current of account.\n",
|
||||
" trade_exchange : Exchange()\n",
|
||||
" exchange.\n",
|
||||
" trade_date : pd.Timestamp\n",
|
||||
" date.\n",
|
||||
" \"\"\"\n",
|
||||
" score_series = score_series.dropna()\n",
|
||||
"\n",
|
||||
" # check stock holdings, if\n",
|
||||
" # 1. doesn't have score: target amount = 0 (force sell)\n",
|
||||
" # 2. stock not tradable: target amount = current amount\n",
|
||||
" current_position = current.get_stock_amount_dict()\n",
|
||||
" target_position = {}\n",
|
||||
" for stock_id in current_position:\n",
|
||||
" if not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
|
||||
" target_position[stock_id] = current_position[stock_id]\n",
|
||||
" elif stock_id not in score_series.index:\n",
|
||||
" target_position[stock_id] = 0\n",
|
||||
" else:\n",
|
||||
" # need to be solved by optimizer\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" # filter scores, if\n",
|
||||
" # 1. kept in `amount_dict` by previous rules\n",
|
||||
" # 2. not tradable\n",
|
||||
" skipped = []\n",
|
||||
" for stock_id in score_series.index:\n",
|
||||
" if stock_id in target_position:\n",
|
||||
" skipped.append(stock_id)\n",
|
||||
" elif not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
|
||||
" skipped.append(stock_id)\n",
|
||||
" score_series = score_series[~score_series.index.isin(skipped)]\n",
|
||||
"\n",
|
||||
" # calc remaining value\n",
|
||||
" current_value = pd.Series({\n",
|
||||
" stock_id: current.get_stock_price(stock_id) * amount\n",
|
||||
" for stock_id, amount in current_position.items()\n",
|
||||
" })\n",
|
||||
" risk_total_value = self.get_risk_degree(trade_date) * current.calculate_value()\n",
|
||||
" traded_value = risk_total_value - current_value.loc[list(target_position)].sum()\n",
|
||||
"\n",
|
||||
" # portfolio init weight\n",
|
||||
" init_weight = current_value.reindex(score_series.index, fill_value=0)\n",
|
||||
" init_weight_sum = init_weight.sum()\n",
|
||||
" if init_weight_sum > 0:\n",
|
||||
" init_weight /= init_weight_sum\n",
|
||||
"\n",
|
||||
" # covariance estimation\n",
|
||||
" selector = (self.data_handler.get_range_selector(pred_date, 252), score_series.index)\n",
|
||||
" price = self.data_handler.fetch(selector, level=None, squeeze=True)\n",
|
||||
" cov = self.cov_estimator(price)\n",
|
||||
" cov = cov.reindex(\n",
|
||||
" index=score_series.index, \n",
|
||||
" columns=score_series.index, \n",
|
||||
" #fill_value=cov.max().max()\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # optimize target portfolio\n",
|
||||
" if init_weight.sum() > 0:\n",
|
||||
" target_weight = self.optimizer(cov, score_series, init_weight)\n",
|
||||
" else:\n",
|
||||
" target_weight = self.optimizer(cov, score_series)\n",
|
||||
" target_weight = target_weight[target_weight > 1e-6]\n",
|
||||
" for stock_id, weight in target_weight.items():\n",
|
||||
" try:\n",
|
||||
" target_position[stock_id] = int(traded_value * weight / trade_exchange.get_close(stock_id, pred_date))\n",
|
||||
" except Exception as e:\n",
|
||||
" # TODO: unknown exception\n",
|
||||
" print('Exception:', e)\n",
|
||||
"\n",
|
||||
" # for debug\n",
|
||||
" print('trade date:', trade_date)\n",
|
||||
" print('target weight:', target_weight.to_dict())\n",
|
||||
" print('target position:', target_position)\n",
|
||||
"\n",
|
||||
" # generate order list\n",
|
||||
" order_list = trade_exchange.generate_order_for_target_amount_position(\n",
|
||||
" target_position=target_position,\n",
|
||||
" current_position=current_position,\n",
|
||||
" trade_date=trade_date,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return order_list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.data.dataset.loader import QlibDataLoader\n",
|
||||
"from qlib.data.dataset.handler import DataHandler\n",
|
||||
"from qlib.model.riskmodel import ShrinkCovEstimator\n",
|
||||
"from qlib.portfolio.optimizer import PortfolioOptimizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"[35366:MainThread](2020-11-27 10:31:56,951) INFO - qlib.timer - [log.py:81] - Time cost: 6.763s | Loading data Done\n",
|
||||
"[35366:MainThread](2020-11-27 10:31:56,953) INFO - qlib.timer - [log.py:81] - Time cost: 6.766s | Init data Done\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"data_loader = QlibDataLoader([\"$close\"])\n",
|
||||
"data_handler = DataHandler(\"all\", \"2015-01-01\", \"2020-08-01\", data_loader)\n",
|
||||
"cov_estimator = ShrinkCovEstimator(nan_option=\"mask\")\n",
|
||||
"optimizer = PortfolioOptimizer(\"mvo\", lamb=2, delta=0.2, tol=1e-5)\n",
|
||||
"strategy = OptBasedStrategy(data_handler, cov_estimator, optimizer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stderr",
|
||||
"text": [
|
||||
"1': 0.08936553334387595, 'SH601800': 0.011014844457113308, 'SH601939': 0.013378001170219945, 'SH603993': 0.013820193926861863, 'SZ000338': 0.002455991798001457, 'SZ000423': 0.004893338273543826, 'SZ000538': 0.010686211189620477, 'SZ002065': 0.09095125419435357, 'SZ002074': 0.010299013738522475, 'SZ002085': 0.19844965949420615, 'SZ002236': 0.09210003831704765, 'SZ002310': 0.05664352912360013, 'SZ300017': 0.0197442255539771}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272224, 'SH600009': 604839, 'SH600018': 3097398, 'SH600028': 335726, 'SH600196': 23243, 'SH600276': 71634, 'SH600519': 17354, 'SH600585': 269686, 'SH600900': 2501521, 'SH601111': 2400659, 'SH601800': 334062, 'SH601939': 1283164, 'SH603993': 742901, 'SZ000338': 95285, 'SZ000423': 21697, 'SZ000538': 14518, 'SZ002065': 498253, 'SZ002074': 111674, 'SZ002085': 591507, 'SZ002236': 394197, 'SZ002310': 2202674, 'SZ300017': 206128}\n",
|
||||
"target weight: {'SH600000': 0.02310668460556249, 'SH600009': 0.06170206213753432, 'SH600018': 0.027608180837257277, 'SH600028': 0.00971532319525714, 'SH600196': 0.0036133308423111116, 'SH600276': 0.093195014492093, 'SH600519': 0.013476706174774766, 'SH600585': 0.036024919027310476, 'SH600660': 0.04512159672692613, 'SH600900': 0.12506534473579556, 'SH601939': 0.013494851810297546, 'SH603993': 0.07619418669734077, 'SZ000338': 0.0024673392047414363, 'SZ000423': 0.00485981529404862, 'SZ000538': 0.010602880875660015, 'SZ002065': 0.09064325205359221, 'SZ002074': 0.0011889996597580427, 'SZ002085': 0.1982091371262038, 'SZ002236': 0.09254320484936242, 'SZ002310': 0.05152917909181458, 'SZ002466': 0.00014732765084648903, 'SZ300017': 0.019490662910321074}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272079, 'SH600009': 604359, 'SH600018': 3095205, 'SH600028': 335471, 'SH600196': 23407, 'SH600276': 71567, 'SH600519': 17345, 'SH600585': 269447, 'SH600660': 129265, 'SH600900': 2499305, 'SH601939': 1282317, 'SH603993': 4058172, 'SZ000338': 95223, 'SZ000423': 21703, 'SZ000538': 14509, 'SZ002065': 497821, 'SZ002074': 12787, 'SZ002085': 590955, 'SZ002236': 393895, 'SZ002310': 2190685, 'SZ002466': 4483, 'SZ300017': 205994}\n",
|
||||
"target weight: {'SH600000': 0.0014042138463464568, 'SH600009': 0.11511740651805806, 'SH600018': 0.026968513725965638, 'SH600028': 0.009566603496832042, 'SH600150': 0.016339328084607228, 'SH600276': 0.09374974543357856, 'SH600489': 0.021876512936684123, 'SH600585': 0.035840818294258524, 'SH600900': 0.12414161958870683, 'SH601888': 0.005682635273269834, 'SH601939': 0.013289788356428228, 'SH603993': 0.07491407610535435, 'SZ000338': 0.002426716760042838, 'SZ000423': 0.00492071038737461, 'SZ000503': 0.005617017904986693, 'SZ000538': 0.010859006699485451, 'SZ002065': 0.08924691553942904, 'SZ002085': 0.19757848255238786, 'SZ002236': 0.09381012783787722, 'SZ002310': 0.03737359938389514, 'SZ300017': 0.01927616131502695}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16809, 'SH600009': 1075516, 'SH600018': 3091248, 'SH600028': 335128, 'SH600150': 114804, 'SH600276': 71473, 'SH600489': 66586, 'SH600585': 268644, 'SH600900': 2496175, 'SH601888': 173824, 'SH601939': 1281108, 'SH603993': 4052802, 'SZ000338': 95107, 'SZ000423': 21684, 'SZ000503': 80461, 'SZ000538': 14507, 'SZ002065': 497197, 'SZ002085': 590211, 'SZ002236': 393412, 'SZ002310': 1573728, 'SZ300017': 205818}\n",
|
||||
"target weight: {'SH600000': 0.0013962189421662084, 'SH600009': 0.09330267135244051, 'SH600018': 0.026443154116291615, 'SH600028': 0.009581412428525829, 'SH600150': 0.016443917649559808, 'SH600276': 0.09378402212481758, 'SH600703': 0.0005233118350013756, 'SH600741': 0.10117549074044105, 'SH600900': 0.12435147566444608, 'SH601888': 0.00560250787284307, 'SH601939': 0.013238798853730008, 'SH603993': 0.07455231781733267, 'SZ000423': 0.0048695925705555185, 'SZ000503': 0.006070996956328167, 'SZ000538': 0.010870567565742796, 'SZ002065': 0.08722983720892508, 'SZ002074': 0.00037126948590009574, 'SZ002085': 0.19840484837030906, 'SZ002236': 0.09365186287123867, 'SZ002310': 0.03806080531862309, 'SZ300017': 7.492025186876957e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16889, 'SH600009': 867443, 'SH600018': 3086467, 'SH600028': 334573, 'SH600150': 114383, 'SH600276': 71360, 'SH600703': 1760, 'SH600741': 665366, 'SH600900': 2491839, 'SH601888': 173465, 'SH601939': 1278590, 'SH603993': 4045939, 'SZ000423': 21674, 'SZ000503': 80212, 'SZ000538': 14499, 'SZ002065': 496361, 'SZ002074': 4086, 'SZ002085': 589224, 'SZ002236': 392766, 'SZ002310': 1571463, 'SZ300017': 805}\n",
|
||||
"target weight: {'SH600000': 0.0014143911110003147, 'SH600018': 0.026834186435965166, 'SH600028': 0.00961324990522086, 'SH600150': 0.015905361405158292, 'SH600276': 0.09486308638260738, 'SH600685': 1.0253334545374858e-06, 'SH600703': 0.0005108576602907958, 'SH600741': 0.10252334336233063, 'SH600900': 0.1250632059809011, 'SH601888': 0.005830869532670813, 'SH601939': 0.01336945356138906, 'SH603993': 0.07101851124599835, 'SZ000423': 0.004899981502195361, 'SZ000503': 0.006113894785564276, 'SZ000538': 0.011081925761176491, 'SZ000709': 1.06442568357325e-06, 'SZ002065': 0.08812103684766726, 'SZ002074': 0.0003564773234700175, 'SZ002085': 0.19097427428977284, 'SZ002236': 0.09299395368630246, 'SZ002310': 0.03841630892378685, 'SZ002475': 0.10001934454071283, 'SZ300017': 7.322667303400442e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16886, 'SH600018': 3080789, 'SH600028': 334087, 'SH600150': 114360, 'SH600276': 71234, 'SH600685': 10, 'SH600703': 1709, 'SH600741': 663932, 'SH600900': 2486951, 'SH601888': 173417, 'SH601939': 1276335, 'SH603993': 3740672, 'SZ000423': 21667, 'SZ000503': 80191, 'SZ000538': 14495, 'SZ000709': 11, 'SZ002065': 495371, 'SZ002074': 3867, 'SZ002085': 588051, 'SZ002236': 392002, 'SZ002310': 1568834, 'SZ002475': 1264636, 'SZ300017': 809}\n",
|
||||
"target weight: {'SH600000': 0.0013872765178790307, 'SH600018': 0.026321999857337998, 'SH600028': 0.009491029058787367, 'SH600150': 0.015749871987744815, 'SH600276': 0.09581999547114961, 'SH600703': 0.000518490273176083, 'SH600741': 0.1037547619508012, 'SH600900': 0.12396253436063161, 'SH601258': 0.02298494942988327, 'SH601888': 0.005915886046387033, 'SH601939': 0.013177336599075601, 'SH603993': 0.06888468621566025, 'SZ000423': 0.005102036718661418, 'SZ000503': 0.00602692511970311, 'SZ000538': 0.011127923667697532, 'SZ000709': 0.07688609680386178, 'SZ002065': 0.08693397271897534, 'SZ002074': 0.000347445594871718, 'SZ002085': 0.1905176824564206, 'SZ002236': 0.035835596544641496, 'SZ002475': 0.09918059167278087, 'SZ300017': 7.291118905149903e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16948, 'SH600018': 3086676, 'SH600028': 334750, 'SH600150': 114560, 'SH600276': 71372, 'SH600703': 1715, 'SH600741': 665129, 'SH600900': 2491433, 'SH601258': 4190669, 'SH601888': 174070, 'SH601939': 1278836, 'SH603993': 3747283, 'SZ000423': 21744, 'SZ000503': 80490, 'SZ000538': 14538, 'SZ000709': 871429, 'SZ002065': 496245, 'SZ002074': 3887, 'SZ002085': 589120, 'SZ002236': 145147, 'SZ002475': 1268582, 'SZ300017': 814}\n",
|
||||
"target weight: {'SH600000': 0.001373124016867567, 'SH600018': 0.02646941123076474, 'SH600028': 0.009458335378810856, 'SH600150': 0.015442533996257352, 'SH600276': 0.09620341387657301, 'SH600649': 0.012613476480118908, 'SH600703': 0.0005280976985716832, 'SH600741': 0.06577156829314017, 'SH600900': 0.12455488881029539, 'SH601258': 0.02270943336842379, 'SH601939': 0.013066707696697587, 'SH603993': 0.0649427819283919, 'SZ000423': 0.0051167756388828005, 'SZ000503': 0.006076486564538039, 'SZ000709': 0.0770418453012855, 'SZ000778': 0.08738918304165759, 'SZ002065': 0.08804613990036694, 'SZ002074': 0.00034315924263262563, 'SZ002085': 0.18241434394629127, 'SZ002475': 0.10035998625624482, 'SZ300017': 7.809604376099223e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16935, 'SH600018': 3089469, 'SH600028': 334906, 'SH600150': 114496, 'SH600276': 71430, 'SH600649': 337388, 'SH600703': 1714, 'SH600741': 419916, 'SH600900': 2493978, 'SH601258': 4194599, 'SH601939': 1279661, 'SH603993': 3750968, 'SZ000423': 21734, 'SZ000503': 80440, 'SZ000709': 872293, 'SZ000778': 366855, 'SZ002065': 496756, 'SZ002074': 3880, 'SZ002085': 564610, 'SZ002475': 1269872, 'SZ300017': 812}\n",
|
||||
"target weight: {'SH600000': 0.0013497287789570015, 'SH600018': 0.02647482761554837, 'SH600028': 0.00941080088689994, 'SH600150': 0.01556139303593115, 'SH600276': 0.09732218714743374, 'SH600649': 0.012606184789019243, 'SH600703': 0.0005334649726542859, 'SH600900': 0.12593267687041163, 'SH601258': 0.021199485570796834, 'SH601939': 0.013025993149697816, 'SH603993': 0.06446918682668012, 'SZ000423': 0.005311875734339093, 'SZ000503': 0.006125989728635501, 'SZ000709': 0.0707610058353687, 'SZ000778': 0.14004715956352495, 'SZ002065': 0.08746446321200681, 'SZ002074': 0.00033710686535540885, 'SZ002085': 0.15238971653801253, 'SZ002146': 0.042585776887618575, 'SZ002475': 0.10701429615740456, 'SZ300017': 7.667981013711115e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 17031, 'SH600018': 3109084, 'SH600028': 336978, 'SH600150': 115126, 'SH600276': 71888, 'SH600649': 339316, 'SH600703': 1724, 'SH600900': 2510148, 'SH601258': 4237748, 'SH601939': 1287810, 'SH603993': 3775382, 'SZ000423': 21853, 'SZ000503': 80885, 'SZ000709': 878077, 'SZ000778': 625157, 'SZ002065': 499988, 'SZ002074': 3901, 'SZ002085': 469624, 'SZ002146': 2000993, 'SZ002475': 1278084, 'SZ300017': 814}\n",
|
||||
"target weight: {'SH600000': 0.0013594926998639766, 'SH600009': 0.021101252574639438, 'SH600028': 0.009528554544265834, 'SH600150': 0.015013601602404225, 'SH600276': 0.09860402207319302, 'SH600649': 0.01292550325031454, 'SH600685': 0.00703471182662378, 'SH600703': 0.0005218767517596246, 'SH600900': 0.12786995199482584, 'SH601258': 0.04401496515184404, 'SH601398': 0.025932829520167643, 'SH601939': 0.0134408200189716, 'SH603993': 0.06319752369639879, 'SZ000423': 0.005221187626834546, 'SZ000503': 0.006085670359590286, 'SZ000568': 0.003081214755480397, 'SZ000709': 0.07061122716452324, 'SZ000778': 0.1379488795662632, 'SZ000839': 0.019142903464547063, 'SZ002065': 0.04714685528331623, 'SZ002074': 0.00033291622875151913, 'SZ002085': 0.11947661465752588, 'SZ002146': 0.043205942689553425, 'SZ002310': 0.0009243182551654129, 'SZ002475': 0.106199974013018, 'SZ300017': 7.709323254732814e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16933, 'SH600009': 196068, 'SH600028': 337025, 'SH600150': 115100, 'SH600276': 71926, 'SH600649': 339354, 'SH600685': 75328, 'SH600703': 1713, 'SH600900': 2511928, 'SH601258': 8791935, 'SH601398': 1146896, 'SH601939': 1288215, 'SH603993': 3777819, 'SZ000423': 21728, 'SZ000503': 80869, 'SZ000568': 10375, 'SZ000709': 878683, 'SZ000778': 625604, 'SZ000839': 312116, 'SZ002065': 268413, 'SZ002074': 3860, 'SZ002085': 369761, 'SZ002146': 2002072, 'SZ002310': 40341, 'SZ002475': 1278918, 'SZ300017': 811}\n",
|
||||
"target weight: {'SH600000': 0.0013764694393366029, 'SH600009': 0.021541655860797534, 'SH600028': 0.009752609535237182, 'SH600276': 0.06514222178877259, 'SH600649': 0.01273168785031133, 'SH600685': 0.006989932070614982, 'SH600900': 0.12998548252109676, 'SH601258': 0.13157540821422453, 'SH601398': 0.02641881439805636, 'SH601939': 0.0136141957873422, 'SH603993': 0.0602411123337629, 'SZ000503': 0.006084251045333903, 'SZ000709': 0.06977363144499521, 'SZ000778': 0.1385461140272643, 'SZ000839': 0.018579865431307987, 'SZ002065': 0.046270476942690986, 'SZ002074': 0.00025974854597178115, 'SZ002085': 0.10060756172850334, 'SZ002146': 0.043204792194791966, 'SZ002310': 0.0009022784286642987, 'SZ002466': 0.011748866835406593, 'SZ002475': 0.08457581284822364, 'SZ300017': 7.701070501151889e-05}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16938, 'SH600009': 196239, 'SH600028': 337355, 'SH600276': 46535, 'SH600649': 339479, 'SH600685': 75274, 'SH600900': 2514488, 'SH601258': 26730440, 'SH601398': 1148157, 'SH601939': 1289259, 'SH603993': 3781937, 'SZ000503': 80900, 'SZ000709': 879645, 'SZ000778': 626285, 'SZ000839': 312384, 'SZ002065': 268717, 'SZ002074': 3093, 'SZ002085': 309206, 'SZ002146': 2003901, 'SZ002310': 39782, 'SZ002466': 367691, 'SZ002475': 1026389, 'SZ300017': 812}\n",
|
||||
"target weight: {'SH600000': 0.0013689894888766726, 'SH600009': 0.021087495457198752, 'SH600028': 0.009589419355091226, 'SH600276': 0.0644304399184473, 'SH600535': 0.016420787426513667, 'SH600649': 0.0267771761277641, 'SH600900': 0.12784455237901315, 'SH601169': 0.004374459372110214, 'SH601258': 0.13288651981531077, 'SH601398': 0.02615927477879055, 'SH601939': 0.013573361058977978, 'SH603993': 1.157895161672162e-06, 'SZ000503': 0.009069218941980683, 'SZ000709': 0.07014466816191627, 'SZ000778': 0.13956352821962528, 'SZ002065': 0.045206445945654664, 'SZ002085': 0.08649963592018277, 'SZ002146': 0.04234588186007612, 'SZ002310': 0.0008924777422846245, 'SZ002466': 0.07334842360184116, 'SZ002475': 0.08834296814868704, 'SZ300017': 7.311841306821287e-05}\n",
|
||||
"Exception: ('SH601169', Timestamp('2017-04-25 00:00:00'))\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16929, 'SH600009': 196092, 'SH600028': 337333, 'SH600276': 46571, 'SH600535': 57649, 'SH600649': 731641, 'SH600900': 2515321, 'SH601258': 26740467, 'SH601398': 1148635, 'SH601939': 1289434, 'SH603993': 72, 'SZ000503': 122157, 'SZ000709': 879908, 'SZ000778': 626506, 'SZ002065': 268767, 'SZ002085': 267906, 'SZ002146': 2004576, 'SZ002310': 39745, 'SZ002466': 2332750, 'SZ002475': 1026858, 'SZ300017': 806}\n",
|
||||
"target weight: {'SH600000': 0.0013439859873209908, 'SH600009': 0.02075652616964347, 'SH600028': 0.00939963933310415, 'SH600276': 0.06236017906066887, 'SH600535': 0.016369568294734148, 'SH600649': 0.025541724367766302, 'SH600900': 0.12768966131041845, 'SH601258': 0.1370446945486361, 'SH601398': 0.02601619218529119, 'SH601939': 0.013440958024818669, 'SH603993': 4.144559709761373e-06, 'SZ000503': 0.0084237188568659, 'SZ000568': 0.020576387679160105, 'SZ000709': 0.056783757531829446, 'SZ000778': 0.06920027928808208, 'SZ002008': 0.07943378393922318, 'SZ002065': 0.045339177613740886, 'SZ002085': 0.08505902525865962, 'SZ002146': 0.031624633954490035, 'SZ002310': 0.0008996156348854183, 'SZ002466': 0.0764983539831682, 'SZ002475': 0.086193992434369}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SZ300017': 812.4573136217659, 'SH600000': 16923, 'SH600009': 196076, 'SH600028': 337279, 'SH600276': 46567, 'SH600535': 57624, 'SH600649': 731549, 'SH600900': 2515891, 'SH601258': 26747448, 'SH601398': 1148886, 'SH601939': 1289307, 'SH603993': 263, 'SZ000503': 122158, 'SZ000568': 69471, 'SZ000709': 700781, 'SZ000778': 302643, 'SZ002008': 746285, 'SZ002065': 268804, 'SZ002085': 267988, 'SZ002146': 1473970, 'SZ002310': 39739, 'SZ002466': 2333288, 'SZ002475': 1027134}\n",
|
||||
"target weight: {'SH600000': 0.0014508867295425067, 'SH600009': 0.022137935734971876, 'SH600028': 0.01003980705499816, 'SH600276': 0.065554410760754, 'SH600535': 0.017337663954140436, 'SH600649': 0.026752732524884384, 'SH600900': 0.13610376526017787, 'SH601258': 0.14230666244775886, 'SH601398': 0.027847743092481312, 'SH601939': 0.014306563408357105, 'SH603993': 2.7770868647848817e-06, 'SZ000069': 0.10104502775773525, 'SZ000503': 0.009049444347506782, 'SZ000568': 0.005686495401232644, 'SZ000778': 0.0715782861850023, 'SZ002008': 0.08609584908472251, 'SZ002065': 0.04706561122827146, 'SZ002085': 0.09099179117275048, 'SZ002146': 0.03204301334262787, 'SZ002475': 0.09241758644387384, 'SZ300017': 0.00018594702102337797}\n",
|
||||
"target position: {'SZ000709': 700825.0269758024, 'SZ002299': 6184584.0980107365, 'SH600000': 16845, 'SH600009': 195098, 'SH600028': 335689, 'SH600276': 46340, 'SH600535': 57343, 'SH600649': 728078, 'SH600900': 2504242, 'SH601258': 26624542, 'SH601398': 1143577, 'SH601939': 1283067, 'SH603993': 160, 'SZ000069': 367637, 'SZ000503': 121565, 'SZ000568': 17626, 'SZ000778': 301250, 'SZ002008': 742790, 'SZ002065': 267559, 'SZ002085': 266737, 'SZ002146': 1467579, 'SZ002475': 1022346, 'SZ300017': 1776}\n",
|
||||
"target weight: {'SH600000': 0.0013484985106016394, 'SH600009': 0.020750773768622693, 'SH600028': 0.009285673867962157, 'SH600104': 2.9067007814076732e-05, 'SH600196': 0.10012804077099052, 'SH600276': 0.05943563439541343, 'SH600535': 0.015902136087846228, 'SH600649': 0.025189836387314323, 'SH600900': 0.12584805827140388, 'SH601111': 6.857382365314848e-06, 'SH601258': 0.03895938466363849, 'SH601398': 0.025753888553878806, 'SH601939': 0.013275755331575599, 'SH603993': 4.249178615404585e-06, 'SZ000069': 0.09445579375504781, 'SZ000503': 0.008532747266799033, 'SZ000568': 0.0052599046052527266, 'SZ000709': 0.06003418476540357, 'SZ000778': 0.06923031488245988, 'SZ002008': 0.07903025205993618, 'SZ002065': 0.04448484691775433, 'SZ002085': 0.08426354045447453, 'SZ002146': 0.031142767130486235, 'SZ002475': 0.08747938111190227, 'SZ300017': 0.00016841662419817417}\n",
|
||||
"target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16906, 'SH600009': 195107, 'SH600028': 335257, 'SH600104': 197, 'SH600196': 630404, 'SH600276': 46282, 'SH600535': 57311, 'SH600649': 727170, 'SH600900': 2500379, 'SH601111': 203, 'SH601258': 7443096, 'SH601398': 1142014, 'SH601939': 1281361, 'SH603993': 263, 'SZ000069': 366998, 'SZ000503': 121479, 'SZ000568': 17699, 'SZ000709': 699639, 'SZ000778': 300752, 'SZ002008': 741767, 'SZ002065': 267133, 'SZ002085': 266334, 'SZ002146': 1465489, 'SZ002475': 1020693, 'SZ300017': 1756}\n",
|
||||
"target weight: {'SH600000': 0.0012976336004362882, 'SH600009': 0.0204756895024156, 'SH600028': 0.008883617000656601, 'SH600104': 2.592943319382378e-05, 'SH600196': 0.09617041827497698, 'SH600276': 0.05681162545715886, 'SH600535': 0.015294256733040745, 'SH600649': 0.02417676167926707, 'SH600900': 0.12233373885315162, 'SH601398': 0.024531954099214746, 'SH601628': 0.005044154324745466, 'SH601888': 0.09500034426651846, 'SH601939': 0.012657033879067425, 'SH603993': 4.079522960136806e-06, 'SZ000069': 0.09054142453059062, 'SZ000503': 0.008036587259744734, 'SZ000568': 0.0049533657881637655, 'SZ000778': 0.06904486736535222, 'SZ002008': 0.06688985213943154, 'SZ002065': 0.04278977877238287, 'SZ002085': 0.0820368284038888, 'SZ002299': 0.06899317887598991, 'SZ002475': 0.08384652594205952, 'SZ300017': 0.00016035416530955983}\n",
|
||||
"target position: {'SH601258': 7443495.190430395, 'SH600000': 16952, 'SH600009': 195676, 'SH600028': 336044, 'SH600104': 183, 'SH600196': 631454, 'SH600276': 46372, 'SH600535': 57498, 'SH600649': 728582, 'SH600900': 2504660, 'SH601398': 1143938, 'SH601628': 695470, 'SH601888': 2951253, 'SH601939': 1283887, 'SH603993': 255, 'SZ000069': 367641, 'SZ000503': 121875, 'SZ000568': 17775, 'SZ000778': 301255, 'SZ002008': 638620, 'SZ002065': 267645, 'SZ002085': 266802, 'SZ002299': 6194843, 'SZ002475': 1022527, 'SZ300017': 1765}\n",
|
||||
"target weight: {'SH600000': 0.0013469483722729403, 'SH600028': 0.009286467498269333, 'SH600104': 2.368500734977497e-05, 'SH600196': 0.10145424564201923, 'SH600276': 0.06002237364700993, 'SH600535': 0.01588332650422844, 'SH600649': 0.025440421851940002, 'SH600900': 0.1279028471227695, 'SH601258': 0.035917606048396986, 'SH601398': 0.02559318344055778, 'SH601628': 0.005221942888216608, 'SH601888': 0.14928498761757883, 'SH601939': 0.013161430940131148, 'SH603993': 4.350147095904942e-06, 'SZ000069': 0.14038473724819095, 'SZ000503': 0.008556251357999256, 'SZ000568': 0.005243511514392524, 'SZ002008': 0.06824325050397591, 'SZ002065': 0.04420632869308568, 'SZ002085': 0.074424247013131, 'SZ002299': 0.0010812901181988855, 'SZ002475': 0.0871460668952185, 'SZ300017': 0.00017049992832446128}\n",
|
||||
"target position: {'SZ000778': 301254.84776855103, 'SH600000': 16873, 'SH600028': 335064, 'SH600104': 156, 'SH600196': 629613, 'SH600276': 46235, 'SH600535': 57245, 'SH600649': 726346, 'SH600900': 2497776, 'SH601258': 7423462, 'SH601398': 1140689, 'SH601628': 692346, 'SH601888': 4557826, 'SH601939': 1279908, 'SH603993': 261, 'SZ000069': 551887, 'SZ000503': 121344, 'SZ000568': 17697, 'SZ002008': 636943, 'SZ002065': 266904, 'SZ002085': 231781, 'SZ002299': 97527, 'SZ002475': 1019747, 'SZ300017': 1749}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"output_type": "error",
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m<ipython-input-49-2e7986244749>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[1;31m# backtest & analysis\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 31\u001b[0m \u001b[0mpar\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mport_analysis_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 32\u001b[1;33m \u001b[0mpar\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\workflow\\record_temp.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, **kwargs)\u001b[0m\n\u001b[0;32m 230\u001b[0m \u001b[1;31m# custom strategy and get backtest\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 231\u001b[0m \u001b[0mpred_score\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 232\u001b[1;33m \u001b[0mreport_normal\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpositions_normal\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnormal_backtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpred_score\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstrategy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbacktest_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"report_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mreport_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 234\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"positions_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mpositions_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\contrib\\evaluate.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, account, shift, benchmark, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 269\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 270\u001b[0m \u001b[0maccount\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0maccount\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 271\u001b[1;33m \u001b[0mbenchmark\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbenchmark\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 272\u001b[0m )\n\u001b[0;32m 273\u001b[0m \u001b[1;31m# for compatibility of the old API. return the dict positions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\contrib\\backtest\\backtest.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, strategy, trade_exchange, shift, verbose, account, benchmark)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[0mtrade_exchange\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_exchange\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[0mpred_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpred_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mtrade_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m )\n\u001b[0;32m 104\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m<ipython-input-46-65beeeee07c0>\u001b[0m in \u001b[0;36mgenerate_order_list\u001b[1;34m(self, score_series, current, trade_exchange, pred_date, trade_date)\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[1;31m# optimize target portfolio\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 79\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[1;31m# optimize\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mw\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 104\u001b[0m \u001b[1;31m# restore index if needed\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 126\u001b[0m \u001b[1;31m# mean-variance\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 127\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOPT_MVO\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 128\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 129\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[1;31m# risk parity\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize_mvo\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 162\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mlamb\u001b[0m\u001b[0;31m`\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mrisk\u001b[0m \u001b[0maversion\u001b[0m \u001b[0mparameter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 163\u001b[0m \"\"\"\n\u001b[1;32m--> 164\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_solve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_objective_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_constrains\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 165\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 166\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_optimize_rp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_solve\u001b[1;34m(self, n, obj, bounds, cons)\u001b[0m\n\u001b[0;32m 252\u001b[0m \u001b[1;31m# solve\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[0mx0\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mones\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mn\u001b[0m \u001b[1;31m# init results\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 254\u001b[1;33m \u001b[0msol\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mso\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mminimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwrapped_obj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbounds\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbounds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconstraints\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcons\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtol\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 255\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msol\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msuccess\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 256\u001b[0m \u001b[0mwarnings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"optimization not success ({sol.status})\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\_minimize.py\u001b[0m in \u001b[0;36mminimize\u001b[1;34m(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)\u001b[0m\n\u001b[0;32m 624\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'slsqp'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 625\u001b[0m return _minimize_slsqp(fun, x0, args, jac, bounds,\n\u001b[1;32m--> 626\u001b[1;33m constraints, callback=callback, **options)\n\u001b[0m\u001b[0;32m 627\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'trust-constr'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 628\u001b[0m return _minimize_trustregion_constr(fun, x0, args, jac, hess, hessp,\n",
|
||||
"\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\slsqp.py\u001b[0m in \u001b[0;36m_minimize_slsqp\u001b[1;34m(func, x0, args, jac, bounds, constraints, maxiter, ftol, iprint, disp, eps, callback, finite_diff_rel_step, **unknown_options)\u001b[0m\n\u001b[0;32m 419\u001b[0m n1, n2, n3)\n\u001b[0;32m 420\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 421\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# objective and constraint evaluation required\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 422\u001b[0m \u001b[0mfx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 423\u001b[0m \u001b[0mc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_eval_constraint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcons\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"###################################\n",
|
||||
"# prediction, backtest & analysis\n",
|
||||
"###################################\n",
|
||||
"port_analysis_config = {\n",
|
||||
" \"strategy\": strategy,\n",
|
||||
" \"backtest\": {\n",
|
||||
" \"verbose\": False,\n",
|
||||
" \"limit_threshold\": 0.095,\n",
|
||||
" \"account\": 100000000,\n",
|
||||
" \"benchmark\": benchmark,\n",
|
||||
" \"deal_price\": \"close\",\n",
|
||||
" \"open_cost\": 0.0005,\n",
|
||||
" \"close_cost\": 0.0015,\n",
|
||||
" \"min_cost\": 5,\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=\"backtest_analysis\"):\n",
|
||||
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
|
||||
" model = recorder.load_object(\"trained_model\")\n",
|
||||
"\n",
|
||||
" # prediction\n",
|
||||
" recorder = R.get_recorder()\n",
|
||||
" ba_rid = recorder.id\n",
|
||||
" sr = SignalRecord(model, dataset, recorder)\n",
|
||||
" sr.generate()\n",
|
||||
"\n",
|
||||
" # backtest & analysis\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config)\n",
|
||||
" par.generate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import shutil
|
||||
import tempfile
|
||||
import statistics
|
||||
from pathlib import Path
|
||||
from operator import xor
|
||||
from subprocess import Popen, PIPE
|
||||
from threading import Thread
|
||||
from pprint import pprint
|
||||
@@ -174,11 +175,21 @@ def cal_mean_std(results) -> dict:
|
||||
|
||||
|
||||
# function to get all the folders benchmark folder
|
||||
def get_all_folders() -> dict:
|
||||
def get_all_folders(models, exclude) -> dict:
|
||||
folders = dict()
|
||||
if isinstance(models, str):
|
||||
model_list = models.split(",")
|
||||
models = [m.lower().strip("[ ]") for m in model_list]
|
||||
elif isinstance(models, list):
|
||||
models = [m.lower() for m in models]
|
||||
elif models is None:
|
||||
models = [f.name.lower() for f in os.scandir("benchmarks")]
|
||||
else:
|
||||
raise ValueError("Input models type is not supported. Please provide str or list without space.")
|
||||
for f in os.scandir("benchmarks"):
|
||||
path = Path("benchmarks") / f.name
|
||||
if f.name != "TFT":
|
||||
add = xor(bool(f.name.lower() in models), bool(exclude))
|
||||
if add:
|
||||
path = Path("benchmarks") / f.name
|
||||
folders[f.name] = str(path.resolve())
|
||||
return folders
|
||||
|
||||
@@ -226,13 +237,44 @@ def gen_and_save_md_table(metrics):
|
||||
|
||||
|
||||
# function to run the all the models
|
||||
def run(times=1):
|
||||
def run(times=1, models=None, exclude=False):
|
||||
"""
|
||||
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
|
||||
Any PR to enhance this method is highly welcomed.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
times : int
|
||||
determines how many times the model should be running.
|
||||
models : str or list
|
||||
determines the specific model or list of models to run or exclude.
|
||||
exclude : boolean
|
||||
determines whether the model being used is excluded or included.
|
||||
|
||||
Usage:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
python run_all_model.py 3
|
||||
|
||||
# Case 2 - run specific models multiple times
|
||||
python run_all_model.py 3 dnn
|
||||
|
||||
# Case 3 - run other models except those are given as arguments for multiple times
|
||||
python run_all_model.py 3 [dnn,tft,lstm] True
|
||||
|
||||
# Case 4 - run specific models for one time
|
||||
python run_all_model.py --models=[dnn,lightgbm]
|
||||
|
||||
# Case 5 - run other models except those are given as aruments for one time
|
||||
python run_all_model.py --models=[dnn,tft,sfm] --exclude=True
|
||||
|
||||
"""
|
||||
# get all folders
|
||||
folders = get_all_folders()
|
||||
folders = get_all_folders(models, exclude)
|
||||
# set up
|
||||
compatible = True
|
||||
if sys.version_info < (3, 3):
|
||||
|
||||
@@ -31,7 +31,8 @@
|
||||
")\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord"
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
"from qlib.utils import flatten_dict"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -129,7 +130,7 @@
|
||||
"\n",
|
||||
"# start exp to train model\n",
|
||||
"with R.start(experiment_name=\"train_model\"):\n",
|
||||
" R.log_paramters(**flatten_dict(task))\n",
|
||||
" R.log_params(**flatten_dict(task))\n",
|
||||
" model.fit(dataset)\n",
|
||||
" R.save_objects(trained_model=model)\n",
|
||||
" rid = R.get_recorder().id\n"
|
||||
@@ -337,4 +338,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
|
||||
138
examples/workflow_by_code_alstm.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "ALSTM",
|
||||
"module_path": "qlib.contrib.model.pytorch_alstm",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.0,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-3,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "IC",
|
||||
"loss": "mse",
|
||||
"seed": 0,
|
||||
"GPU": "0",
|
||||
"rnn_type": "GRU",
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -7,19 +7,15 @@ from pathlib import Path
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.pytorch_gats import GAT
|
||||
from qlib.contrib.data.handler import ALPHA360_Denoise
|
||||
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
|
||||
# from qlib.model.learner import train_model
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pickle
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -65,17 +61,16 @@ if __name__ == "__main__":
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.0,
|
||||
"dropout": 0.7,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-3,
|
||||
"lr": 1e-4,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"base_model": "LSTM",
|
||||
"with_pretrain": True,
|
||||
"seed": 0,
|
||||
"GPU": 0,
|
||||
"GPU": "0",
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
@@ -98,7 +93,6 @@ if __name__ == "__main__":
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
# model = train_model(task)
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
@@ -70,7 +70,7 @@ if __name__ == "__main__":
|
||||
"lr": 1e-3,
|
||||
"early_stop": 20,
|
||||
"batch_size": 800,
|
||||
"metric": "IC",
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"seed": 0,
|
||||
"GPU": 0,
|
||||
|
||||
136
examples/workflow_by_code_hats.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
|
||||
from get_data import GetData
|
||||
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
MARKET = "csi300"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
DATA_HANDLER_CONFIG = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": MARKET,
|
||||
}
|
||||
|
||||
TRAINER_CONFIG = {
|
||||
"train_start_time": "2008-01-01",
|
||||
"train_end_time": "2014-12-31",
|
||||
"validate_start_time": "2015-01-01",
|
||||
"validate_end_time": "2016-12-31",
|
||||
"test_start_time": "2017-01-01",
|
||||
"test_end_time": "2020-08-01",
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "HATS",
|
||||
"module_path": "qlib.contrib.model.pytorch_hats",
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"num_layers": 2,
|
||||
"dropout": 0.7,
|
||||
"n_epochs": 200,
|
||||
"lr": 1e-4,
|
||||
"early_stop": 20,
|
||||
"metric": "loss",
|
||||
"loss": "mse",
|
||||
"base_model": "LSTM",
|
||||
"seed": 0,
|
||||
"GPU": "2",
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "ALPHA360_Denoise",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": DATA_HANDLER_CONFIG,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
}
|
||||
# You shoud record the data in specific sequence
|
||||
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
|
||||
}
|
||||
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset, save_path="benchmarks/HATS/model_hat.pkl")
|
||||
|
||||
pred_score = model.predict(dataset)
|
||||
|
||||
# save pred_score to file
|
||||
pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
|
||||
pred_score_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
pred_score.to_pickle(pred_score_path)
|
||||
|
||||
###################################
|
||||
# backtest
|
||||
###################################
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
|
||||
# use default strategy
|
||||
# custom Strategy, refer to: TODO: Strategy API url
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
|
||||
|
||||
###################################
|
||||
# analyze
|
||||
# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
|
||||
###################################
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
print(analysis_df)
|
||||
@@ -1,5 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -62,21 +72,22 @@ if __name__ == "__main__":
|
||||
"kwargs": {
|
||||
"d_feat": 6,
|
||||
"hidden_size": 64,
|
||||
"output_dim": 1,
|
||||
"freq_dim": 15,
|
||||
"output_dim": 32,
|
||||
"freq_dim": 25,
|
||||
"dropout_W": 0.5,
|
||||
"dropout_U": 0.5,
|
||||
"n_epochs": 10,
|
||||
"n_epochs": 15,
|
||||
"lr": 1e-3,
|
||||
"batch_size": 800,
|
||||
"metric": "",
|
||||
"batch_size": 1600,
|
||||
"early_stop": 20,
|
||||
"eval_steps": 5,
|
||||
"loss": "mse",
|
||||
"lr_decay": 0.96,
|
||||
"lr_decay_steps": 100,
|
||||
"optimizer": "gd",
|
||||
"GPU": 1,
|
||||
"seed": 0,
|
||||
"optimizer": "adam",
|
||||
"GPU": 3,
|
||||
"seed": 710,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
|
||||
@@ -64,7 +64,7 @@ class Config:
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
|
||||
NUM_USABLE_CPU = multiprocessing.cpu_count() - 2
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
_default_config = {
|
||||
# data provider config
|
||||
|
||||
@@ -10,6 +10,28 @@ from inspect import getfullargspec
|
||||
import copy
|
||||
|
||||
|
||||
def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
args = getfullargspec(klass).args
|
||||
if "fit_start_time" in args and "fit_end_time" in args:
|
||||
assert (
|
||||
fit_start_time is not None and fit_end_time is not None
|
||||
), "Make sure `fit_start_time` and `fit_end_time` are not None."
|
||||
pkwargs.update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
|
||||
class ALPHA360_Denoise(DataHandlerLP):
|
||||
def __init__(self, instruments="csi500", start_time=None, end_time=None, fit_start_time=None, fit_end_time=None):
|
||||
data_loader = {
|
||||
@@ -83,28 +105,42 @@ class ALPHA360_Denoise(DataHandlerLP):
|
||||
return fields, names
|
||||
|
||||
|
||||
_DEFAULT_LEARN_PROCESSORS = [
|
||||
{"class": "DropnaLabel"},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
]
|
||||
_DEFAULT_INFER_PROCESSORS = [
|
||||
{"class": "ProcessInf", "kwargs": {}},
|
||||
{"class": "ZScoreNorm", "kwargs": {}},
|
||||
{"class": "Fillna", "kwargs": {}},
|
||||
]
|
||||
|
||||
|
||||
class ALPHA360(DataHandlerLP):
|
||||
def __init__(self, instruments="csi500", start_time=None, end_time=None, fit_start_time=None, fit_end_time=None):
|
||||
def __init__(
|
||||
self,
|
||||
instruments="csi500",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=_DEFAULT_INFER_PROCESSORS,
|
||||
learn_processors=_DEFAULT_LEARN_PROCESSORS,
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
**kwargs,
|
||||
):
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {
|
||||
"feature": self.get_feature_config(),
|
||||
"label": self.get_label_config(),
|
||||
"label": kwargs.get("label", self.get_label_config()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
learn_processors = [
|
||||
{"class": "DropnaLabel", "kwargs": {"fields_group": "label"}},
|
||||
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
|
||||
]
|
||||
infer_processors = [
|
||||
{"class": "ProcessInf", "kwargs": {}},
|
||||
{"class": "ZscoreNorm", "kwargs": {"fit_start_time": fit_start_time, "fit_end_time": fit_end_time}},
|
||||
{"class": "Fillna", "kwargs": {}},
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
instruments,
|
||||
start_time,
|
||||
@@ -168,39 +204,19 @@ class Alpha158(DataHandlerLP):
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
infer_processors=[],
|
||||
learn_processors=["DropnaLabel", {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}],
|
||||
learn_processors=_DEFAULT_LEARN_PROCESSORS,
|
||||
fit_start_time=None,
|
||||
fit_end_time=None,
|
||||
process_type=DataHandlerLP.PTYPE_A
|
||||
**kwargs,
|
||||
):
|
||||
def check_transform_proc(proc_l):
|
||||
new_l = []
|
||||
for p in proc_l:
|
||||
if not isinstance(p, Processor):
|
||||
klass, pkwargs = get_cls_kwargs(p, processor_module)
|
||||
args = getfullargspec(klass).args
|
||||
if "fit_start_time" in args and "fit_end_time" in args:
|
||||
assert (
|
||||
fit_start_time is not None and fit_end_time is not None
|
||||
), "Make sure `fit_start_time` and `fit_end_time` are not None."
|
||||
pkwargs.update(
|
||||
{
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
return new_l
|
||||
|
||||
infer_processors = check_transform_proc(infer_processors)
|
||||
learn_processors = check_transform_proc(learn_processors)
|
||||
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
|
||||
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
|
||||
|
||||
data_loader = {
|
||||
"class": "QlibDataLoader",
|
||||
"kwargs": {
|
||||
"config": {"feature": self.get_feature_config(), "label": self.get_label_config()},
|
||||
"config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())},
|
||||
},
|
||||
}
|
||||
super().__init__(
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import yaml
|
||||
import copy
|
||||
import os
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from ...config import REG_CN
|
||||
|
||||
|
||||
class EstimatorConfigManager(object):
|
||||
def __init__(self, config_path):
|
||||
|
||||
if not config_path:
|
||||
raise ValueError("Config path is invalid.")
|
||||
self.config_path = config_path
|
||||
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.FullLoader)
|
||||
self.config = copy.deepcopy(config)
|
||||
|
||||
self.ex_config = ExperimentConfig(config.get("experiment", dict()), self)
|
||||
self.data_config = DataConfig(config.get("data", dict()), self)
|
||||
self.model_config = ModelConfig(config.get("model", dict()), self)
|
||||
self.trainer_config = TrainerConfig(config.get("trainer", dict()), self)
|
||||
self.strategy_config = StrategyConfig(config.get("strategy", dict()), self)
|
||||
self.backtest_config = BacktestConfig(config.get("backtest", dict()), self)
|
||||
self.qlib_data_config = QlibDataConfig(config.get("qlib_data", dict()), self)
|
||||
|
||||
# If the start_date and end_date are not given in data_config, they will be referred from the trainer_config.
|
||||
handler_start_date = self.data_config.handler_parameters.get("start_date", None)
|
||||
handler_end_date = self.data_config.handler_parameters.get("end_date", None)
|
||||
if handler_start_date is None:
|
||||
self.data_config.handler_parameters["start_date"] = self.trainer_config.parameters["train_start_date"]
|
||||
if handler_end_date is None:
|
||||
self.data_config.handler_parameters["end_date"] = self.trainer_config.parameters["test_end_date"]
|
||||
|
||||
|
||||
class ExperimentConfig(object):
|
||||
TRAIN_MODE = "train"
|
||||
TEST_MODE = "test"
|
||||
|
||||
OBSERVER_FILE_STORAGE = "file_storage"
|
||||
OBSERVER_MONGO = "mongo"
|
||||
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for experiment
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.name = config.get("name", "test_experiment")
|
||||
# The dir of the result of all the experiments
|
||||
self.global_dir = config.get("dir", os.path.dirname(CONFIG_MANAGER.config_path))
|
||||
# The dir of the result of current experiment
|
||||
self.ex_dir = os.path.join(self.global_dir, self.name)
|
||||
if not os.path.exists(self.ex_dir):
|
||||
os.makedirs(self.ex_dir)
|
||||
self.tmp_run_dir = tempfile.mkdtemp(dir=self.ex_dir)
|
||||
self.mode = config.get("mode", ExperimentConfig.TRAIN_MODE)
|
||||
self.sacred_dir = os.path.join(self.ex_dir, "sacred")
|
||||
self.observer_type = config.get("observer_type", ExperimentConfig.OBSERVER_FILE_STORAGE)
|
||||
self.mongo_url = config.get("mongo_url", None)
|
||||
self.db_name = config.get("db_name", None)
|
||||
self.finetune = config.get("finetune", False)
|
||||
|
||||
# The path of the experiment id of the experiment
|
||||
self.exp_info_path = config.get("exp_info_path", os.path.join(self.ex_dir, "exp_info.json"))
|
||||
exp_info_dir = Path(self.exp_info_path).parent
|
||||
exp_info_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Test mode config
|
||||
loader_args = config.get("loader", dict())
|
||||
if self.mode == ExperimentConfig.TEST_MODE or self.finetune:
|
||||
loader_exp_info_path = loader_args.get("exp_info_path", None)
|
||||
self.loader_model_index = loader_args.get("model_index", None)
|
||||
if (loader_exp_info_path is not None) and (os.path.exists(loader_exp_info_path)):
|
||||
with open(loader_exp_info_path) as fp:
|
||||
loader_dict = json.load(fp)
|
||||
for k, v in loader_dict.items():
|
||||
setattr(self, "loader_{}".format(k), v)
|
||||
# Check loader experiment id
|
||||
assert hasattr(self, "loader_id"), "If mode is test or finetune is True, loader must contain id."
|
||||
else:
|
||||
self.loader_id = loader_args.get("id", None)
|
||||
if self.loader_id is None:
|
||||
raise ValueError("If mode is test or finetune is True, loader must contain id.")
|
||||
|
||||
self.loader_observer_type = loader_args.get("observer_type", self.observer_type)
|
||||
self.loader_name = loader_args.get("name", self.name)
|
||||
self.loader_dir = loader_args.get("dir", self.global_dir)
|
||||
|
||||
self.loader_mongo_url = loader_args.get("mongo_url", self.mongo_url)
|
||||
self.loader_db_name = loader_args.get("db_name", self.db_name)
|
||||
|
||||
|
||||
class DataConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for data
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.handler_module_path = config.get("module_path", "qlib.contrib.data.handler")
|
||||
self.handler_class = config.get("class", "ALPHA360")
|
||||
self.handler_parameters = config.get("args", dict())
|
||||
self.handler_filter = config.get("filter", dict())
|
||||
# Update provider uri.
|
||||
|
||||
|
||||
class ModelConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for model
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.model_class = config.get("class", "Model")
|
||||
self.model_module_path = config.get("module_path", "qlib.model")
|
||||
self.save_dir = os.path.join(CONFIG_MANAGER.ex_config.tmp_run_dir, "model")
|
||||
self.save_path = config.get("save_path", os.path.join(self.save_dir, "model.bin"))
|
||||
self.parameters = config.get("args", dict())
|
||||
# Make dir if need.
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
|
||||
|
||||
class TrainerConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for trainer
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.trainer_class = config.get("class", "StaticTrainer")
|
||||
self.trainer_module_path = config.get("module_path", "qlib.contrib.estimator.trainer")
|
||||
self.parameters = config.get("args", dict())
|
||||
|
||||
|
||||
class StrategyConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGER):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for strategy
|
||||
:param CONFIG_MANAGER: The estimator config manager
|
||||
"""
|
||||
self.strategy_class = config.get("class", "TopkDropoutStrategy")
|
||||
self.strategy_module_path = config.get("module_path", "qlib.contrib.strategy.strategy")
|
||||
self.parameters = config.get("args", dict())
|
||||
|
||||
|
||||
class BacktestConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGE):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for strategy
|
||||
:param CONFIG_MANAGE: The estimator config manager
|
||||
"""
|
||||
self.normal_backtest_parameters = config.get("normal_backtest_args", dict())
|
||||
self.long_short_backtest_parameters = config.get("long_short_backtest_args", dict())
|
||||
|
||||
|
||||
class QlibDataConfig(object):
|
||||
def __init__(self, config, CONFIG_MANAGE):
|
||||
"""__init__
|
||||
|
||||
:param config: The config dict for qlib_client
|
||||
:param CONFIG_MANAGE: The estimator config manager
|
||||
"""
|
||||
self.provider_uri = config.pop("provider_uri", "~/.qlib/qlib_data/cn_data")
|
||||
self.auto_mount = config.pop("auto_mount", False)
|
||||
self.mount_path = config.pop("mount_path", "~/.qlib/qlib_data/cn_data")
|
||||
self.region = config.pop("region", REG_CN)
|
||||
self.args = config
|
||||
@@ -1,328 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import yaml
|
||||
import pickle
|
||||
|
||||
import qlib
|
||||
from ..evaluate import risk_analysis
|
||||
from ..evaluate import backtest as normal_backtest
|
||||
from ..evaluate import long_short_backtest
|
||||
from .config import ExperimentConfig
|
||||
from .fetcher import create_fetcher_with_config
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_module_by_module_path, compare_dict_value
|
||||
|
||||
|
||||
class Estimator(object):
|
||||
def __init__(self, config_manager, sacred_ex):
|
||||
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("Estimator")
|
||||
|
||||
# 1. Set config manager.
|
||||
self.config_manager = config_manager
|
||||
|
||||
# 2. Set configs.
|
||||
self.ex_config = config_manager.ex_config
|
||||
self.data_config = config_manager.data_config
|
||||
self.model_config = config_manager.model_config
|
||||
self.trainer_config = config_manager.trainer_config
|
||||
self.strategy_config = config_manager.strategy_config
|
||||
self.backtest_config = config_manager.backtest_config
|
||||
|
||||
# If experiment.mode is test or experiment.finetune is True, load the experimental results in the loader
|
||||
if self.ex_config.mode == self.ex_config.TEST_MODE or self.ex_config.finetune:
|
||||
self.compare_config_with_config_manger(self.config_manager)
|
||||
|
||||
# 3. Set sacred_experiment.
|
||||
self.ex = sacred_ex
|
||||
|
||||
# 4. Init data handler.
|
||||
self.data_handler = None
|
||||
self._init_data_handler()
|
||||
|
||||
# 5. Init trainer.
|
||||
self.trainer = None
|
||||
self._init_trainer()
|
||||
|
||||
# 6. Init strategy.
|
||||
self.strategy = None
|
||||
self._init_strategy()
|
||||
|
||||
def _init_data_handler(self):
|
||||
handler_module = get_module_by_module_path(self.data_config.handler_module_path)
|
||||
|
||||
# Set market
|
||||
market = self.data_config.handler_filter.get("market", None)
|
||||
if market is None:
|
||||
if "market" in self.data_config.handler_parameters:
|
||||
self.logger.warning(
|
||||
"Warning: The market in data.args section is deprecated. "
|
||||
"It only works when market is not set in data.filter section. "
|
||||
"It will be overridden by market in the data.filter section."
|
||||
)
|
||||
market = self.data_config.handler_parameters["market"]
|
||||
else:
|
||||
market = "csi500"
|
||||
|
||||
self.data_config.handler_parameters["market"] = market
|
||||
|
||||
data_filter_list = []
|
||||
handler_filters = self.data_config.handler_filter.get("filter_pipeline", list())
|
||||
for h_filter in handler_filters:
|
||||
filter_module_path = h_filter.get("module_path", "qlib.data.filter")
|
||||
filter_class_name = h_filter.get("class", "")
|
||||
filter_parameters = h_filter.get("args", {})
|
||||
filter_module = get_module_by_module_path(filter_module_path)
|
||||
filter_class = getattr(filter_module, filter_class_name)
|
||||
data_filter = filter_class(**filter_parameters)
|
||||
data_filter_list.append(data_filter)
|
||||
|
||||
self.data_config.handler_parameters["data_filter_list"] = data_filter_list
|
||||
handler_class = getattr(handler_module, self.data_config.handler_class)
|
||||
self.data_handler = handler_class(**self.data_config.handler_parameters)
|
||||
|
||||
def _init_trainer(self):
|
||||
|
||||
model_module = get_module_by_module_path(self.model_config.model_module_path)
|
||||
trainer_module = get_module_by_module_path(self.trainer_config.trainer_module_path)
|
||||
model_class = getattr(model_module, self.model_config.model_class)
|
||||
trainer_class = getattr(trainer_module, self.trainer_config.trainer_class)
|
||||
|
||||
self.trainer = trainer_class(
|
||||
model_class,
|
||||
self.model_config.save_path,
|
||||
self.model_config.parameters,
|
||||
self.data_handler,
|
||||
self.ex,
|
||||
**self.trainer_config.parameters
|
||||
)
|
||||
|
||||
def _init_strategy(self):
|
||||
|
||||
module = get_module_by_module_path(self.strategy_config.strategy_module_path)
|
||||
strategy_class = getattr(module, self.strategy_config.strategy_class)
|
||||
self.strategy = strategy_class(**self.strategy_config.parameters)
|
||||
|
||||
def run(self):
|
||||
if self.ex_config.mode == ExperimentConfig.TRAIN_MODE:
|
||||
self.trainer.train()
|
||||
elif self.ex_config.mode == ExperimentConfig.TEST_MODE:
|
||||
self.trainer.load()
|
||||
else:
|
||||
raise ValueError("unexpected mode: %s" % self.ex_config.mode)
|
||||
analysis = self.backtest()
|
||||
print(analysis)
|
||||
self.logger.info(
|
||||
"experiment id: {}, experiment name: {}".format(self.ex.experiment.current_run._id, self.ex_config.name)
|
||||
)
|
||||
|
||||
# Remove temp dir
|
||||
# shutil.rmtree(self.ex_config.tmp_run_dir)
|
||||
|
||||
def backtest(self):
|
||||
TimeInspector.set_time_mark()
|
||||
# 1. Get pred and prediction score of model(s).
|
||||
pred = self.trainer.get_test_score()
|
||||
try:
|
||||
performance = self.trainer.get_test_performance()
|
||||
except NotImplementedError:
|
||||
performance = None
|
||||
# 2. Normal Backtest.
|
||||
report_normal, positions_normal = self._normal_backtest(pred)
|
||||
# 3. Long-Short Backtest.
|
||||
# Deprecated
|
||||
# long_short_reports = self._long_short_backtest(pred)
|
||||
# 4. Analyze
|
||||
analysis_df = self._analyze(report_normal)
|
||||
# 5. Save.
|
||||
self._save_backtest_result(
|
||||
pred,
|
||||
analysis_df,
|
||||
positions_normal,
|
||||
report_normal,
|
||||
# long_short_reports,
|
||||
performance,
|
||||
)
|
||||
return analysis_df
|
||||
|
||||
def _normal_backtest(self, pred):
|
||||
TimeInspector.set_time_mark()
|
||||
if "account" not in self.backtest_config.normal_backtest_parameters:
|
||||
if "account" in self.strategy_config.parameters:
|
||||
self.logger.warning(
|
||||
"Warning: The account in strategy section is deprecated. "
|
||||
"It only works when account is not set in backtest section. "
|
||||
"It will be overridden by account in the backtest section."
|
||||
)
|
||||
self.backtest_config.normal_backtest_parameters["account"] = self.strategy_config.parameters["account"]
|
||||
report_normal, positions_normal = normal_backtest(
|
||||
pred, strategy=self.strategy, **self.backtest_config.normal_backtest_parameters
|
||||
)
|
||||
TimeInspector.log_cost_time("Finished normal backtest.")
|
||||
return report_normal, positions_normal
|
||||
|
||||
def _long_short_backtest(self, pred):
|
||||
TimeInspector.set_time_mark()
|
||||
long_short_reports = long_short_backtest(pred, **self.backtest_config.long_short_backtest_parameters)
|
||||
TimeInspector.log_cost_time("Finished long-short backtest.")
|
||||
return long_short_reports
|
||||
|
||||
@staticmethod
|
||||
def _analyze(report_normal):
|
||||
TimeInspector.set_time_mark()
|
||||
|
||||
analysis = dict()
|
||||
# analysis["pred_long"] = risk_analysis(long_short_reports["long"])
|
||||
# analysis["pred_short"] = risk_analysis(long_short_reports["short"])
|
||||
# analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"])
|
||||
analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"]
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
TimeInspector.log_cost_time(
|
||||
"Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean())
|
||||
)
|
||||
return analysis_df
|
||||
|
||||
def _save_backtest_result(self, pred, analysis, positions, report_normal, performance):
|
||||
# 1. Result dir.
|
||||
result_dir = os.path.join(self.config_manager.ex_config.tmp_run_dir, "result")
|
||||
if not os.path.exists(result_dir):
|
||||
os.makedirs(result_dir)
|
||||
|
||||
self.ex.add_info(
|
||||
"task_config",
|
||||
json.loads(json.dumps(self.config_manager.config, default=str)),
|
||||
)
|
||||
|
||||
# 2. Pred.
|
||||
TimeInspector.set_time_mark()
|
||||
pred_pkl_path = os.path.join(result_dir, "pred.pkl")
|
||||
pred.to_pickle(pred_pkl_path)
|
||||
self.ex.add_artifact(pred_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving pred.pkl to: {}".format(pred_pkl_path))
|
||||
|
||||
# 3. Ana.
|
||||
TimeInspector.set_time_mark()
|
||||
analysis_pkl_path = os.path.join(result_dir, "analysis.pkl")
|
||||
analysis.to_pickle(analysis_pkl_path)
|
||||
self.ex.add_artifact(analysis_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving analysis.pkl to: {}".format(analysis_pkl_path))
|
||||
|
||||
# 4. Pos.
|
||||
TimeInspector.set_time_mark()
|
||||
positions_pkl_path = os.path.join(result_dir, "positions.pkl")
|
||||
with open(positions_pkl_path, "wb") as fp:
|
||||
pickle.dump(positions, fp)
|
||||
self.ex.add_artifact(positions_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving positions.pkl to: {}".format(positions_pkl_path))
|
||||
|
||||
# 5. Report normal.
|
||||
TimeInspector.set_time_mark()
|
||||
report_normal_pkl_path = os.path.join(result_dir, "report_normal.pkl")
|
||||
report_normal.to_pickle(report_normal_pkl_path)
|
||||
self.ex.add_artifact(report_normal_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving report_normal.pkl to: {}".format(report_normal_pkl_path))
|
||||
|
||||
# 6. Report long short.
|
||||
# Deprecated
|
||||
# for k, name in zip(
|
||||
# ["long", "short", "long_short"],
|
||||
# ["report_long.pkl", "report_short.pkl", "report_long_short.pkl"],
|
||||
# ):
|
||||
# TimeInspector.set_time_mark()
|
||||
# pkl_path = os.path.join(result_dir, name)
|
||||
# long_short_reports[k].to_pickle(pkl_path)
|
||||
# self.ex.add_artifact(pkl_path)
|
||||
# TimeInspector.log_cost_time("Finished saving {} to: {}".format(name, pkl_path))
|
||||
|
||||
# 7. Origin test label.
|
||||
TimeInspector.set_time_mark()
|
||||
label_pkl_path = os.path.join(result_dir, "label.pkl")
|
||||
self.data_handler.get_origin_test_label_with_date(
|
||||
self.trainer_config.parameters["test_start_date"],
|
||||
self.trainer_config.parameters["test_end_date"],
|
||||
).to_pickle(label_pkl_path)
|
||||
self.ex.add_artifact(label_pkl_path)
|
||||
TimeInspector.log_cost_time("Finished saving label.pkl to: {}".format(label_pkl_path))
|
||||
|
||||
# 8. Experiment info, save the model(s) performance here.
|
||||
TimeInspector.set_time_mark()
|
||||
cur_ex_id = self.ex.experiment.current_run._id
|
||||
exp_info = {
|
||||
"id": cur_ex_id,
|
||||
"name": self.ex_config.name,
|
||||
"performance": performance,
|
||||
"observer_type": self.ex_config.observer_type,
|
||||
}
|
||||
|
||||
if self.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
|
||||
exp_info.update(
|
||||
{
|
||||
"mongo_url": self.ex_config.mongo_url,
|
||||
"db_name": self.ex_config.db_name,
|
||||
}
|
||||
)
|
||||
else:
|
||||
exp_info.update({"dir": self.ex_config.global_dir})
|
||||
|
||||
with open(self.ex_config.exp_info_path, "w") as fp:
|
||||
json.dump(exp_info, fp, indent=4, sort_keys=True)
|
||||
self.ex.add_artifact(self.ex_config.exp_info_path)
|
||||
TimeInspector.log_cost_time("Finished saving ex_info to: {}".format(self.ex_config.exp_info_path))
|
||||
|
||||
@staticmethod
|
||||
def compare_config_with_config_manger(config_manager):
|
||||
"""Compare loader model args and current config with ConfigManage
|
||||
|
||||
:param config_manager: ConfigManager
|
||||
:return:
|
||||
"""
|
||||
fetcher = create_fetcher_with_config(config_manager, load_form_loader=True)
|
||||
loader_mode_config = fetcher.get_experiment(
|
||||
exp_name=config_manager.ex_config.loader_name,
|
||||
exp_id=config_manager.ex_config.loader_id,
|
||||
fields=["task_config"],
|
||||
)["task_config"]
|
||||
with open(config_manager.config_path) as fp:
|
||||
current_config = yaml.load(fp.read())
|
||||
current_config = json.loads(json.dumps(current_config, default=str))
|
||||
|
||||
logger = get_module_logger("Estimator")
|
||||
|
||||
loader_mode_config = copy.deepcopy(loader_mode_config)
|
||||
current_config = copy.deepcopy(current_config)
|
||||
|
||||
# Require test_mode_config.test_start_date <= current_config.test_start_date
|
||||
loader_trainer_args = loader_mode_config.get("trainer", {}).get("args", {})
|
||||
cur_trainer_args = current_config.get("trainer", {}).get("args", {})
|
||||
loader_start_date = loader_trainer_args.pop("test_start_date")
|
||||
cur_test_start_date = cur_trainer_args.pop("test_start_date")
|
||||
assert (
|
||||
loader_start_date <= cur_test_start_date
|
||||
), "Require: loader_mode_config.test_start_date <= current_config.test_start_date"
|
||||
|
||||
# TODO: For the user's own extended `Trainer`, the support is not very good
|
||||
if "RollingTrainer" == current_config.get("trainer", {}).get("class", None):
|
||||
loader_period = loader_trainer_args.pop("rolling_period")
|
||||
cur_period = cur_trainer_args.pop("rolling_period")
|
||||
assert (
|
||||
loader_period == cur_period
|
||||
), "Require: loader_mode_config.rolling_period == current_config.rolling_period"
|
||||
|
||||
compare_section = ["trainer", "model", "data"]
|
||||
for section in compare_section:
|
||||
changes = compare_dict_value(loader_mode_config.get(section, {}), current_config.get(section, {}))
|
||||
if changes:
|
||||
logger.warning("Warning: Loader mode config and current config, `{}` are different:\n".format(section))
|
||||
@@ -1,290 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
import copy
|
||||
import json
|
||||
import yaml
|
||||
import pickle
|
||||
import gridfs
|
||||
import pymongo
|
||||
from pathlib import Path
|
||||
from abc import abstractmethod
|
||||
|
||||
from .config import EstimatorConfigManager, ExperimentConfig
|
||||
|
||||
|
||||
class Fetcher(object):
|
||||
"""Sacred Experiments Fetcher"""
|
||||
|
||||
@abstractmethod
|
||||
def _get_experiment(self, exp_name, exp_id):
|
||||
"""Get experiment basic info with experiment and experiment id
|
||||
|
||||
:param exp_name: experiment name
|
||||
:param exp_id: experiment id
|
||||
:return: dict
|
||||
Must contain keys: _id, experiment, info, stop_time.
|
||||
Here is an example below for FileFetcher.
|
||||
exp = {
|
||||
'_id': exp_id, # experiment id
|
||||
'path': path, # experiment result path
|
||||
'experiment': {'name': exp_name}, # experiment
|
||||
'info': info, # experiment config info
|
||||
'stop_time': run.get('stop_time', None) # The time the experiment ended
|
||||
}
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _list_experiments(self, exp_name=None):
|
||||
"""Get experiment basic info list with experiment name
|
||||
|
||||
:param exp_name: experiment name
|
||||
:return: list
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _iter_artifacts(self, experiment):
|
||||
"""Get information about the data in the experiment results
|
||||
|
||||
:param experiment: `self._get_experiment` method result
|
||||
:return: iterable
|
||||
Each element contains two elements.
|
||||
first element : data name
|
||||
second element : data uri
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _load_data(self, uri):
|
||||
"""Load data with uri
|
||||
|
||||
:param uri: data uri
|
||||
:return: bytes
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def model_dict_to_buffer_list(model_dict):
|
||||
"""
|
||||
|
||||
:param model_dict:
|
||||
:return:
|
||||
"""
|
||||
model_list = []
|
||||
is_static_model = False
|
||||
if len(model_dict) == 1 and list(model_dict.keys())[0] == "model.bin":
|
||||
is_static_model = True
|
||||
model_list.append(list(model_dict.values())[0])
|
||||
else:
|
||||
sep = "model.bin_"
|
||||
model_ids = list(map(lambda x: int(x.split(sep)[1]), model_dict.keys()))
|
||||
min_id, max_id = min(model_ids), max(model_ids)
|
||||
for i in range(min_id, max_id + 1):
|
||||
model_key = sep + str(i)
|
||||
model = model_dict.get(model_key, None)
|
||||
if model is None:
|
||||
print(
|
||||
"WARNING: In Fetcher, {} is missing when the get model is in the get_experiment function.".format(
|
||||
model_key
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
model_list.append(model)
|
||||
|
||||
if is_static_model:
|
||||
return model_list[0]
|
||||
|
||||
return model_list
|
||||
|
||||
def get_experiments(self, exp_name=None):
|
||||
"""Get experiments with name.
|
||||
|
||||
:param exp_name: str
|
||||
If `exp_name` is set to None, then all experiments will return.
|
||||
:return: dict
|
||||
Experiments info dict(Including experiment id and task_config to run the
|
||||
experiment). Here is an example below.
|
||||
{
|
||||
'a_experiment': [
|
||||
{
|
||||
'id': '1',
|
||||
'task_config': {...}
|
||||
},
|
||||
...
|
||||
]
|
||||
...
|
||||
}
|
||||
"""
|
||||
res = dict()
|
||||
for ex in self._list_experiments(exp_name):
|
||||
name = ex["experiment"]["name"]
|
||||
tmp = {
|
||||
"id": ex["_id"],
|
||||
"task_config": ex["info"].get("task_config", {}),
|
||||
"ex_run_stop_time": ex.get("stop_time", None),
|
||||
}
|
||||
res.setdefault(name, []).append(tmp)
|
||||
return res
|
||||
|
||||
def get_experiment(self, exp_name, exp_id, fields=None):
|
||||
"""
|
||||
|
||||
:param exp_name:
|
||||
:param exp_id:
|
||||
:param fields: list
|
||||
Experiment result fields, if fields is None, will get all fields.
|
||||
Currently supported fields:
|
||||
['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
|
||||
:return: dict
|
||||
"""
|
||||
fields = copy.copy(fields)
|
||||
ex = self._get_experiment(exp_name, exp_id)
|
||||
results = dict()
|
||||
model_dict = dict()
|
||||
for name, uri in self._iter_artifacts(ex):
|
||||
# When saving, use `sacred.experiment.add_artifact(filename)` , so `name` is os.path.basename(filename)
|
||||
prefix = name.split(".")[0]
|
||||
if fields and prefix not in fields:
|
||||
continue
|
||||
data = self._load_data(uri)
|
||||
if prefix == "model":
|
||||
model_dict[name] = data
|
||||
else:
|
||||
results[prefix] = pickle.loads(data)
|
||||
# Sort model
|
||||
if model_dict:
|
||||
results["model"] = self.model_dict_to_buffer_list(model_dict)
|
||||
|
||||
# Info
|
||||
results["task_config"] = ex["info"].get("task_config", {})
|
||||
return results
|
||||
|
||||
def estimator_config_to_dict(self, exp_name, exp_id):
|
||||
"""Save configuration to file
|
||||
|
||||
:param exp_name:
|
||||
:param exp_id:
|
||||
:return: config dict
|
||||
"""
|
||||
|
||||
return self.get_experiment(exp_name, exp_id, fields=["task_config"])["task_config"]
|
||||
|
||||
|
||||
class FileFetcher(Fetcher):
|
||||
"""File Fetcher"""
|
||||
|
||||
def __init__(self, experiments_dir):
|
||||
self.experiments_dir = Path(experiments_dir)
|
||||
|
||||
def _get_experiment(self, exp_name, exp_id):
|
||||
path = self.experiments_dir / exp_name / "sacred" / str(exp_id)
|
||||
info_path = path / "info.json"
|
||||
run_path = path / "run.json"
|
||||
|
||||
if info_path.exists():
|
||||
with info_path.open("r") as f:
|
||||
info = json.load(f)
|
||||
else:
|
||||
info = {}
|
||||
|
||||
if run_path.exists():
|
||||
with run_path.open("r") as f:
|
||||
run = json.load(f)
|
||||
else:
|
||||
run = {}
|
||||
|
||||
exp = {
|
||||
"_id": exp_id,
|
||||
"path": path,
|
||||
"experiment": {"name": exp_name},
|
||||
"info": info,
|
||||
"stop_time": run.get("stop_time", None),
|
||||
}
|
||||
return exp
|
||||
|
||||
def _list_experiments(self, exp_name=None):
|
||||
runs = []
|
||||
for path in self.experiments_dir.glob("{}/sacred/[!_]*".format(exp_name or "*")):
|
||||
exp_name, exp_id = path.parents[1].name, path.name
|
||||
runs.append(self._get_experiment(exp_name, exp_id))
|
||||
return runs
|
||||
|
||||
def _iter_artifacts(self, experiment):
|
||||
if experiment is None:
|
||||
return []
|
||||
|
||||
for fname in experiment["path"].iterdir():
|
||||
if fname.suffix == ".pkl" or ".bin" in fname.suffix:
|
||||
name, uri = fname.name, str(fname)
|
||||
yield name, uri
|
||||
|
||||
def _load_data(self, uri):
|
||||
with open(uri, "rb") as f:
|
||||
data = f.read()
|
||||
return data
|
||||
|
||||
|
||||
class MongoFetcher(Fetcher):
|
||||
"""MongoDB Fetcher"""
|
||||
|
||||
def __init__(self, mongo_url, db_name):
|
||||
self.mongo_url = mongo_url
|
||||
self.db_name = db_name
|
||||
self.client = None
|
||||
self.db = None
|
||||
self.runs = None
|
||||
self.fs = None
|
||||
self._setup_mongo_client()
|
||||
|
||||
def _setup_mongo_client(self):
|
||||
self.client = pymongo.MongoClient(self.mongo_url)
|
||||
self.db = self.client[self.db_name]
|
||||
self.runs = self.db.runs
|
||||
self.fs = gridfs.GridFS(self.db)
|
||||
|
||||
def _get_experiment(self, exp_name, exp_id):
|
||||
return self.runs.find_one({"_id": exp_id})
|
||||
|
||||
def _list_experiments(self, exp_name=None):
|
||||
if exp_name is None:
|
||||
return self.runs.find()
|
||||
return self.runs.find({"experiment.name": exp_name})
|
||||
|
||||
def _iter_artifacts(self, experiment):
|
||||
if experiment is None:
|
||||
return []
|
||||
for artifact in experiment.get("artifacts", []):
|
||||
name, uri = artifact["name"], artifact["file_id"]
|
||||
yield name, uri
|
||||
|
||||
def _load_data(self, uri):
|
||||
data = self.fs.get(uri).read()
|
||||
return data
|
||||
|
||||
|
||||
def create_fetcher_with_config(config_manager: EstimatorConfigManager, load_form_loader: bool = False):
|
||||
"""Create fetcher with loader config
|
||||
|
||||
:param config_manager:
|
||||
:param load_form_loader
|
||||
:return:
|
||||
"""
|
||||
flag = ""
|
||||
if load_form_loader:
|
||||
flag = "loader_"
|
||||
if config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_FILE_STORAGE:
|
||||
return FileFetcher(eval("config_manager.ex_config.{}_dir".format("loader" if load_form_loader else "global")))
|
||||
elif config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
|
||||
return MongoFetcher(
|
||||
mongo_url=eval("config_manager.ex_config.{}mongo_url".format(flag)),
|
||||
db_name=eval("config_manager.ex_config.{}db_name".format(flag)),
|
||||
)
|
||||
else:
|
||||
return NotImplementedError("Unkown Backend")
|
||||
@@ -1,115 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
from ... import init
|
||||
from .config import EstimatorConfigManager
|
||||
from ...log import get_module_logger
|
||||
from sacred import Experiment
|
||||
from sacred.observers import FileStorageObserver
|
||||
from sacred.observers import MongoObserver
|
||||
|
||||
args_parser = argparse.ArgumentParser(prog="estimator")
|
||||
args_parser.add_argument(
|
||||
"-c",
|
||||
"--config_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="json config path indicates where to load config.",
|
||||
)
|
||||
|
||||
args = args_parser.parse_args()
|
||||
|
||||
|
||||
class SacredExperiment(object):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name,
|
||||
experiment_dir,
|
||||
observer_type="file_storage",
|
||||
mongo_url=None,
|
||||
db_name=None,
|
||||
):
|
||||
"""__init__
|
||||
|
||||
:param experiment_name: The name of the experiments.
|
||||
:param experiment_dir: The directory to store all the results of the experiments(This is for file_storage).
|
||||
:param observer_type: The observer to record the results: the `file_storage` or `mongo`
|
||||
:param mongo_url: The mongo url(for mongo observer)
|
||||
:param db_name: The mongo url(for mongo observer)
|
||||
"""
|
||||
self.experiment_name = experiment_name
|
||||
self.experiment = Experiment(self.experiment_name)
|
||||
self.experiment_dir = experiment_dir
|
||||
self.experiment.logger = get_module_logger("Sacred")
|
||||
|
||||
self.observer_type = observer_type
|
||||
self.mongo_db_url = mongo_url
|
||||
self.mongo_db_name = db_name
|
||||
|
||||
self._setup_experiment()
|
||||
|
||||
def _setup_experiment(self):
|
||||
if self.observer_type == "file_storage":
|
||||
file_storage_observer = FileStorageObserver.create(basedir=self.experiment_dir)
|
||||
self.experiment.observers.append(file_storage_observer)
|
||||
elif self.observer_type == "mongo":
|
||||
mongo_observer = MongoObserver.create(url=self.mongo_db_url, db_name=self.mongo_db_name)
|
||||
self.experiment.observers.append(mongo_observer)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported observer type: {}".format(self.observer_type))
|
||||
|
||||
def add_artifact(self, filename):
|
||||
self.experiment.add_artifact(filename)
|
||||
|
||||
def add_info(self, key, value):
|
||||
self.experiment.info[key] = value
|
||||
|
||||
def main_wrapper(self, func):
|
||||
return self.experiment.main(func)
|
||||
|
||||
def config_wrapper(self, func):
|
||||
return self.experiment.config(func)
|
||||
|
||||
|
||||
CONFIG_MANAGER = EstimatorConfigManager(args.config_path)
|
||||
|
||||
ex = SacredExperiment(
|
||||
CONFIG_MANAGER.ex_config.name,
|
||||
CONFIG_MANAGER.ex_config.sacred_dir,
|
||||
observer_type=CONFIG_MANAGER.ex_config.observer_type,
|
||||
mongo_url=CONFIG_MANAGER.ex_config.mongo_url,
|
||||
db_name=CONFIG_MANAGER.ex_config.db_name,
|
||||
)
|
||||
|
||||
# qlib init
|
||||
init(
|
||||
provider_uri=CONFIG_MANAGER.qlib_data_config.provider_uri,
|
||||
mount_path=CONFIG_MANAGER.qlib_data_config.mount_path,
|
||||
auto_mount=CONFIG_MANAGER.qlib_data_config.auto_mount,
|
||||
region=CONFIG_MANAGER.qlib_data_config.region,
|
||||
**CONFIG_MANAGER.qlib_data_config.args
|
||||
)
|
||||
|
||||
|
||||
@ex.main_wrapper
|
||||
def _main():
|
||||
# 1. Get estimator class.
|
||||
estimator_class = getattr(
|
||||
importlib.import_module(".estimator", package="qlib.contrib.estimator"),
|
||||
"Estimator",
|
||||
)
|
||||
# 2. Init estimator.
|
||||
estimator = estimator_class(CONFIG_MANAGER, ex)
|
||||
estimator.run()
|
||||
|
||||
|
||||
def run():
|
||||
ex.experiment.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
@@ -1,317 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from scipy.stats import pearsonr
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from .launcher import CONFIG_MANAGER
|
||||
from .fetcher import create_fetcher_with_config
|
||||
from ...utils import drop_nan_by_y_index, transform_end_date
|
||||
|
||||
|
||||
class BaseTrainer(object):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler: DataHandlerLP, sacred_ex, **kwargs):
|
||||
# 1. Model.
|
||||
self.model_class = model_class
|
||||
self.model_save_path = model_save_path
|
||||
self.model_args = model_args
|
||||
|
||||
# 2. Data handler.
|
||||
self.data_handler = data_handler
|
||||
|
||||
# 3. Sacred ex.
|
||||
self.ex = sacred_ex
|
||||
|
||||
# 4. Logger.
|
||||
self.logger = get_module_logger("Trainer")
|
||||
|
||||
# 5. Data time
|
||||
self.train_start_date = kwargs.get("train_start_date", None)
|
||||
self.train_end_date = kwargs.get("train_end_date", None)
|
||||
self.validate_start_date = kwargs.get("validate_start_date", None)
|
||||
self.validate_end_date = kwargs.get("validate_end_date", None)
|
||||
self.test_start_date = kwargs.get("test_start_date", None)
|
||||
self.test_end_date = transform_end_date(kwargs.get("test_end_date", None))
|
||||
|
||||
@abstractmethod
|
||||
def train(self):
|
||||
"""
|
||||
Implement this method indicating how to train a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
"""
|
||||
Implement this method indicating how to restore a model and the data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_test_pred(self):
|
||||
"""
|
||||
Implement this method indicating how to get prediction result(s) from a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_test_performance(self):
|
||||
"""
|
||||
Implement this method indicating how to get the performance of the model.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement `get_test_performance`")
|
||||
|
||||
def get_test_score(self):
|
||||
"""
|
||||
Override this method to transfer the predict result(s) into the score of the stock.
|
||||
Note: If this is a multi-label training, you need to transfer predict labels into one score.
|
||||
Or you can just use the result of `get_test_pred()` (you can also process the result) if this is one label training.
|
||||
We use the first column of the result of `get_test_pred()` as default method (regard it as one label training).
|
||||
"""
|
||||
pred = self.get_test_pred()
|
||||
pred_score = pd.DataFrame(index=pred.index)
|
||||
pred_score["score"] = pred.iloc(axis=1)[0]
|
||||
return pred_score
|
||||
|
||||
|
||||
class StaticTrainer(BaseTrainer):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
|
||||
super(StaticTrainer, self).__init__(model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs)
|
||||
self.model = None
|
||||
|
||||
split_data = self.data_handler.get_split_data(
|
||||
self.train_start_date,
|
||||
self.train_end_date,
|
||||
self.validate_start_date,
|
||||
self.validate_end_date,
|
||||
self.test_start_date,
|
||||
self.test_end_date,
|
||||
)
|
||||
(
|
||||
self.x_train,
|
||||
self.y_train,
|
||||
self.x_validate,
|
||||
self.y_validate,
|
||||
self.x_test,
|
||||
self.y_test,
|
||||
) = split_data
|
||||
|
||||
def train(self):
|
||||
TimeInspector.set_time_mark()
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
if CONFIG_MANAGER.ex_config.finetune:
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
|
||||
if isinstance(loader_model, list):
|
||||
model_index = (
|
||||
-1
|
||||
if CONFIG_MANAGER.ex_config.loader_model_index is None
|
||||
else CONFIG_MANAGER.ex_config.loader_model_index
|
||||
)
|
||||
loader_model = loader_model[model_index]
|
||||
|
||||
model.load(loader_model)
|
||||
model.finetune(self.x_train, self.y_train, self.x_validate, self.y_validate)
|
||||
else:
|
||||
model.fit(self.x_train, self.y_train, self.x_validate, self.y_validate)
|
||||
model.save(self.model_save_path)
|
||||
self.ex.add_artifact(self.model_save_path)
|
||||
self.model = model
|
||||
TimeInspector.log_cost_time("Finished training model.")
|
||||
|
||||
def load(self):
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
# Load model
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
|
||||
if isinstance(loader_model, list):
|
||||
model_index = (
|
||||
-1
|
||||
if CONFIG_MANAGER.ex_config.loader_model_index is None
|
||||
else CONFIG_MANAGER.ex_config.loader_model_index
|
||||
)
|
||||
loader_model = loader_model[model_index]
|
||||
|
||||
model.load(loader_model)
|
||||
|
||||
# Save model, after load, if you don't save the model, the result of this experiment will be no model
|
||||
model.save(self.model_save_path)
|
||||
self.ex.add_artifact(self.model_save_path)
|
||||
self.model = model
|
||||
|
||||
def get_test_pred(self):
|
||||
pred = self.model.predict(self.x_test)
|
||||
pred = pd.DataFrame(pred, index=self.x_test.index, columns=self.y_test.columns)
|
||||
return pred
|
||||
|
||||
def get_test_performance(self):
|
||||
try:
|
||||
model_score = self.model.score(self.x_test, self.y_test)
|
||||
except NotImplementedError:
|
||||
model_score = None
|
||||
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
x_test, y_test, __ = drop_nan_by_y_index(self.x_test, self.y_test)
|
||||
pred_test = self.model.predict(x_test)
|
||||
model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
|
||||
|
||||
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
|
||||
return performance
|
||||
|
||||
|
||||
class RollingTrainer(BaseTrainer):
|
||||
def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
|
||||
super(RollingTrainer, self).__init__(
|
||||
model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs
|
||||
)
|
||||
self.rolling_period = kwargs.get("rolling_period", 60)
|
||||
self.models = []
|
||||
self.rolling_data = []
|
||||
self.all_x_test = []
|
||||
self.all_y_test = []
|
||||
for data in self.data_handler.get_rolling_data(
|
||||
self.train_start_date,
|
||||
self.train_end_date,
|
||||
self.validate_start_date,
|
||||
self.validate_end_date,
|
||||
self.test_start_date,
|
||||
self.test_end_date,
|
||||
self.rolling_period,
|
||||
):
|
||||
self.rolling_data.append(data)
|
||||
__, __, __, __, x_test, y_test = data
|
||||
self.all_x_test.append(x_test)
|
||||
self.all_y_test.append(y_test)
|
||||
|
||||
def train(self):
|
||||
# 1. Get total data parts.
|
||||
# total_data_parts = self.data_handler.total_data_parts
|
||||
# self.logger.warning('Total numbers of model are: {}, start training models...'.format(total_data_parts))
|
||||
if CONFIG_MANAGER.ex_config.finetune:
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
loader_model_index = CONFIG_MANAGER.ex_config.loader_model_index
|
||||
previous_model_path = ""
|
||||
# 2. Rolling train.
|
||||
for (
|
||||
index,
|
||||
(x_train, y_train, x_validate, y_validate, x_test, y_test),
|
||||
) in enumerate(self.rolling_data):
|
||||
TimeInspector.set_time_mark()
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
if CONFIG_MANAGER.ex_config.finetune:
|
||||
# Finetune model
|
||||
if loader_model_index is None and isinstance(loader_model, list):
|
||||
try:
|
||||
model.load(loader_model[index])
|
||||
except IndexError:
|
||||
# Load model by previous_model_path
|
||||
with open(previous_model_path, "rb") as fp:
|
||||
model.load(fp)
|
||||
model.finetune(x_train, y_train, x_validate, y_validate)
|
||||
else:
|
||||
|
||||
if index == 0:
|
||||
loader_model = (
|
||||
loader_model[loader_model_index] if isinstance(loader_model, list) else loader_model
|
||||
)
|
||||
model.load(loader_model)
|
||||
else:
|
||||
with open(previous_model_path, "rb") as fp:
|
||||
model.load(fp)
|
||||
|
||||
model.finetune(x_train, y_train, x_validate, y_validate)
|
||||
|
||||
else:
|
||||
model.fit(x_train, y_train, x_validate, y_validate)
|
||||
|
||||
model_save_path = "{}_{}".format(self.model_save_path, index)
|
||||
model.save(model_save_path)
|
||||
previous_model_path = model_save_path
|
||||
self.ex.add_artifact(model_save_path)
|
||||
self.models.append(model)
|
||||
TimeInspector.log_cost_time("Finished training model: {}.".format(index + 1))
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
Load the data and the model
|
||||
"""
|
||||
fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
|
||||
loader_model = fetcher.get_experiment(
|
||||
exp_name=CONFIG_MANAGER.ex_config.loader_name,
|
||||
exp_id=CONFIG_MANAGER.ex_config.loader_id,
|
||||
fields=["model"],
|
||||
)["model"]
|
||||
for index in range(len(self.all_x_test)):
|
||||
model = self.model_class(**self.model_args)
|
||||
|
||||
model.load(loader_model[index])
|
||||
|
||||
# Save model
|
||||
model_save_path = "{}_{}".format(self.model_save_path, index)
|
||||
model.save(model_save_path)
|
||||
self.ex.add_artifact(model_save_path)
|
||||
|
||||
self.models.append(model)
|
||||
|
||||
def get_test_pred(self):
|
||||
"""
|
||||
Predict the score on test data with the models.
|
||||
Please ensure the models and data are loaded before call this score.
|
||||
|
||||
:return: the predicted scores for the pred
|
||||
"""
|
||||
pred_df_list = []
|
||||
y_test_columns = self.all_y_test[0].columns
|
||||
# Start iteration.
|
||||
for model, x_test in zip(self.models, self.all_x_test):
|
||||
pred = model.predict(x_test)
|
||||
pred_df = pd.DataFrame(pred, index=x_test.index, columns=y_test_columns)
|
||||
pred_df_list.append(pred_df)
|
||||
return pd.concat(pred_df_list)
|
||||
|
||||
def get_test_performance(self):
|
||||
"""
|
||||
Get the performances of the models
|
||||
|
||||
:return: the performances of models
|
||||
"""
|
||||
pred_test_list = []
|
||||
y_test_list = []
|
||||
scorer = self.models[0]._scorer
|
||||
for model, x_test, y_test in zip(self.models, self.all_x_test, self.all_y_test):
|
||||
# Remove rows from x, y and w, which contain Nan in any columns in y_test.
|
||||
x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
|
||||
pred_test_list.append(model.predict(x_test))
|
||||
y_test_list.append(np.squeeze(y_test.values))
|
||||
|
||||
pred_test_array = np.concatenate(pred_test_list, axis=0)
|
||||
y_test_array = np.concatenate(y_test_list, axis=0)
|
||||
|
||||
model_score = scorer(y_test_array, pred_test_array)
|
||||
model_pearsonr = pearsonr(np.ravel(y_test_array), np.ravel(pred_test_array))[0]
|
||||
|
||||
performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
|
||||
return performance
|
||||
@@ -26,9 +26,9 @@ def risk_analysis(r, N=252):
|
||||
Parameters
|
||||
----------
|
||||
r : pandas.Series
|
||||
daily return series
|
||||
daily return series.
|
||||
N: int
|
||||
scaler for annualizing information_ratio (day: 250, week: 50, month: 12)
|
||||
scaler for annualizing information_ratio (day: 250, week: 50, month: 12).
|
||||
"""
|
||||
mean = r.mean()
|
||||
std = r.std(ddof=1)
|
||||
@@ -61,7 +61,7 @@ def get_strategy(
|
||||
----------
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
@@ -73,14 +73,14 @@ def get_strategy(
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
|
||||
sell_limit should be no less than topk
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -126,21 +126,21 @@ def get_exchange(
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange()
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
@@ -193,20 +193,20 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
|
||||
- **backtest workflow related or commmon arguments**
|
||||
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column
|
||||
predict should has <datetime, instrument> index and one `score` column.
|
||||
account : float
|
||||
init account value
|
||||
init account value.
|
||||
shift : int
|
||||
whether to shift prediction by one day
|
||||
whether to shift prediction by one day.
|
||||
benchmark : str
|
||||
benchmark code, default is SH000905 CSI 500
|
||||
benchmark code, default is SH000905 CSI 500.
|
||||
verbose : bool
|
||||
whether to print log
|
||||
whether to print log.
|
||||
|
||||
- **strategy related arguments**
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
@@ -218,33 +218,33 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
|
||||
sell_limit should be no less than topk
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
- **exchange related arguments**
|
||||
|
||||
|
||||
exchange: Exchange()
|
||||
pass the exchange for speeding up.
|
||||
subscribe_fields: list
|
||||
subscribe fields
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. The default value is 0.002(0.2%).
|
||||
close_cost : float
|
||||
close transaction cost. The default value is 0.002(0.2%).
|
||||
min_cost : float
|
||||
min transaction cost
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
@@ -291,17 +291,17 @@ def long_short_backtest(
|
||||
"""
|
||||
A backtest for long-short strategy
|
||||
|
||||
:param pred: The trading signal produced on day `T`
|
||||
:param topk: The short topk securities and long topk securities
|
||||
:param deal_price: The price to deal the trading
|
||||
:param pred: The trading signal produced on day `T`.
|
||||
:param topk: The short topk securities and long topk securities.
|
||||
:param deal_price: The price to deal the trading.
|
||||
:param shift: Whether to shift prediction by one day. The trading day will be T+1 if shift==1.
|
||||
:param open_cost: open transaction cost
|
||||
:param close_cost: close transaction cost
|
||||
:param trade_unit: 100 for China A
|
||||
:param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit
|
||||
:param min_cost: min transaction cost
|
||||
:param subscribe_fields: subscribe fields
|
||||
:param extract_codes: bool
|
||||
:param open_cost: open transaction cost.
|
||||
:param close_cost: close transaction cost.
|
||||
:param trade_unit: 100 for China A.
|
||||
:param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit.
|
||||
:param min_cost: min transaction cost.
|
||||
:param subscribe_fields: subscribe fields.
|
||||
:param extract_codes: bool.
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
:return: The result of backtest, it is represented by a dict.
|
||||
|
||||
@@ -1,3 +1,15 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from catboost import Pool, CatBoost
|
||||
|
||||
349
qlib/contrib/model/pytorch_alstm.py
Normal file
@@ -0,0 +1,349 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class ALSTM(Model):
|
||||
"""ALSTM Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
GPU="0",
|
||||
seed=0,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ALSTM")
|
||||
self.logger.info("ALSTM pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.visible_GPU = GPU
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"ALSTM parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
self.ALSTM_model = ALSTMModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
if self.use_gpu:
|
||||
self.ALSTM_model.cuda()
|
||||
# set the visible GPU
|
||||
if self.visible_GPU:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.ALSTM_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float()
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.ALSTM_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.ALSTM_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float()
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.ALSTM_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.ALSTM_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
self.ALSTM_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
x_batch = x_batch.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(x_batch).detach().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class ALSTMModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU"):
|
||||
super().__init__()
|
||||
self.hid_size = hidden_size
|
||||
self.input_size = d_feat
|
||||
self.dropout = dropout
|
||||
self.rnn_type = rnn_type
|
||||
self.rnn_layer = num_layers
|
||||
self._build_model()
|
||||
|
||||
def _build_model(self):
|
||||
try:
|
||||
klass = getattr(nn, self.rnn_type.upper())
|
||||
except:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
|
||||
self.net = nn.Sequential()
|
||||
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
|
||||
self.net.add_module("act", nn.Tanh())
|
||||
self.rnn = klass(
|
||||
input_size=self.hid_size,
|
||||
hidden_size=self.hid_size,
|
||||
num_layers=self.rnn_layer,
|
||||
batch_first=True,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(in_features=self.hid_size * 2, out_features=1)
|
||||
self.att_net = nn.Sequential()
|
||||
self.att_net.add_module("att_fc_in", nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)))
|
||||
self.att_net.add_module("att_dropout", torch.nn.Dropout(self.dropout))
|
||||
self.att_net.add_module("att_act", nn.Tanh())
|
||||
self.att_net.add_module("att_fc_out", nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False))
|
||||
self.att_net.add_module("att_softmax", nn.Softmax(dim=1))
|
||||
|
||||
def forward(self, inputs):
|
||||
# inputs: [batch_size, input_size*input_day]
|
||||
inputs = inputs.view(len(inputs), self.input_size, -1)
|
||||
inputs = inputs.permute(0, 2, 1) # [batch, input_size, seq_len] -> [batch, seq_len, input_size]
|
||||
rnn_out, _ = self.rnn(self.net(inputs)) # [batch, seq_len, num_directions * hidden_size]
|
||||
attention_score = self.att_net(rnn_out) # [batch, seq_len, 1]
|
||||
out_att = torch.mul(rnn_out, attention_score)
|
||||
out_att = torch.sum(out_att, dim=1)
|
||||
out = self.fc_out(
|
||||
torch.cat((rnn_out[:, -1, :], out_att), dim=1)
|
||||
) # [batch, seq_len, num_directions * hidden_size] -> [batch, 1]
|
||||
return out[..., 0]
|
||||
@@ -9,10 +9,8 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import create_save_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -28,14 +26,12 @@ class GAT(Model):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
input dimension
|
||||
output_dim : int
|
||||
output dimension
|
||||
layers : tuple
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
@@ -50,8 +46,7 @@ class GAT(Model):
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="IC",
|
||||
batch_size=2000,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
@@ -73,7 +68,6 @@ class GAT(Model):
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
@@ -92,7 +86,6 @@ class GAT(Model):
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
@@ -108,7 +101,6 @@ class GAT(Model):
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
@@ -120,10 +112,6 @@ class GAT(Model):
|
||||
)
|
||||
)
|
||||
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.GAT_model = GATModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -160,34 +148,37 @@ class GAT(Model):
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
if self.metric == "IC":
|
||||
return self.cal_ic(pred[mask], label[mask])
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def cal_ic(self, pred, label):
|
||||
return torch.mean(pred * label)
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily inter as daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle the daily inter data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values) * 100
|
||||
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
self.GAT_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
# organize the train data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float()
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float()
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float()
|
||||
label = torch.from_numpy(y_train_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
@@ -212,16 +203,13 @@ class GAT(Model):
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
np.random.shuffle(indices)
|
||||
# organize the test data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float()
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float()
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float()
|
||||
label = torch.from_numpy(y_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
@@ -254,7 +242,6 @@ class GAT(Model):
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
@@ -265,12 +252,14 @@ class GAT(Model):
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
|
||||
pretrained_model = LSTMModel()
|
||||
pretrained_model.load_state_dict(torch.load('benchmarks/LSTM/model_lstm_csi300.pkl'))
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
|
||||
elif self.base_model == "GRU":
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
pretrained_model = GRUModel()
|
||||
pretrained_model.load_state_dict(torch.load('benchmarks/GRU/model_gru_csi300.pkl'))
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
@@ -319,17 +308,14 @@ class GAT(Model):
|
||||
index = x_test.index
|
||||
self.GAT_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
# organize the data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float()
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
x_batch = x_batch.cuda()
|
||||
@@ -375,7 +361,6 @@ class GATModel(nn.Module):
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
|
||||
self.d_feat = d_feat
|
||||
|
||||
def cal_convariance(self, x, y): # the 2nd dimension of x and y are the same
|
||||
@@ -394,12 +379,7 @@ class GATModel(nn.Module):
|
||||
out, _ = self.rnn(x)
|
||||
hidden = out[:, -1, :]
|
||||
hidden = self.bn1(hidden)
|
||||
|
||||
gamma = self.cal_convariance(hidden, hidden)
|
||||
# gamma = hidden.mm(torch.t(hidden))
|
||||
# gamma = self.leaky_relu(gamma)
|
||||
# gamma = self.softmax(gamma)
|
||||
# gamma = gamma * (torch.ones(x.shape[0], x.shape[0]).to(device) - torch.diag(torch.ones(x.shape[0])).to(device))
|
||||
output = gamma.mm(hidden)
|
||||
output = self.fc(output)
|
||||
output = self.bn2(output)
|
||||
|
||||
@@ -28,14 +28,10 @@ class GRU(Model):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
input dimension
|
||||
output_dim : int
|
||||
output dimension
|
||||
layers : tuple
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
@@ -50,7 +46,7 @@ class GRU(Model):
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="IC",
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
@@ -112,10 +108,6 @@ class GRU(Model):
|
||||
)
|
||||
)
|
||||
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.gru_model = GRUModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
|
||||
)
|
||||
@@ -148,21 +140,16 @@ class GRU(Model):
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
if self.metric == "IC":
|
||||
return self.cal_ic(pred[mask], label[mask])
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def cal_ic(self, pred, label):
|
||||
return torch.mean(pred * label)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values) * 100
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.gru_model.train()
|
||||
|
||||
@@ -201,7 +188,6 @@ class GRU(Model):
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
@@ -251,7 +237,6 @@ class GRU(Model):
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
# return
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
|
||||
491
qlib/contrib/model/pytorch_hats.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import create_save_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class HATS(Model):
|
||||
"""HATS Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.5,
|
||||
n_epochs=200,
|
||||
lr=0.01,
|
||||
metric="",
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
base_model="GRU",
|
||||
with_pretrain=True,
|
||||
optimizer="adam",
|
||||
GPU="0",
|
||||
seed=0,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("HATS")
|
||||
self.logger.info("HATS pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.visible_GPU = GPU
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"HATS parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
base_model,
|
||||
with_pretrain,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
self.HATS_model = HATSModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.HATS_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.HATS_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
if self.use_gpu:
|
||||
self.HATS_model.cuda()
|
||||
# set the visible GPU
|
||||
if self.visible_GPU:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def get_daily_inter(self, df, shuffle=False):
|
||||
# organize the train data into daily inter as daily batches
|
||||
daily_count = df.groupby(level=0).size().values
|
||||
daily_index = np.roll(np.cumsum(daily_count), 1)
|
||||
daily_index[0] = 0
|
||||
if shuffle:
|
||||
# shuffle the daily inter data
|
||||
daily_shuffle = list(zip(daily_index, daily_count))
|
||||
np.random.shuffle(daily_shuffle)
|
||||
daily_index, daily_count = zip(*daily_shuffle)
|
||||
return daily_index, daily_count
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.HATS_model.train()
|
||||
|
||||
# organize the train data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_train_values[batch]).float()
|
||||
label = torch.from_numpy(y_train_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
pred = self.HATS_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.HATS_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare testing data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.HATS_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
# organize the test data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float()
|
||||
label = torch.from_numpy(y_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
feature = feature.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
pred = self.HATS_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# load pretrained base_model
|
||||
if self.with_pretrain:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
if self.base_model == "LSTM":
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
|
||||
pretrained_model = LSTMModel()
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
|
||||
elif self.base_model == "GRU":
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
|
||||
pretrained_model = GRUModel()
|
||||
pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
|
||||
model_dict = self.HATS_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.HATS_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.HATS_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.HATS_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
self.HATS_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
# organize the data into daily inter as daily batches
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float()
|
||||
|
||||
if self.use_gpu:
|
||||
x_batch = x_batch.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.HATS_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.HATS_model(x_batch).detach().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class HATSModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
|
||||
super().__init__()
|
||||
|
||||
if base_model == "GRU":
|
||||
self.model = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.model = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.bn1 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
|
||||
self.fc = nn.Linear(hidden_size, hidden_size)
|
||||
self.bn2 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
self.leaky_relu = nn.LeakyReLU()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
self.d_feat = d_feat
|
||||
|
||||
num_head_att = [1] * num_layers
|
||||
hidden_dim = [hidden_size] * num_layers
|
||||
dims = [d_feat] + [d * nh for (d, nh) in zip(hidden_dim, num_head_att[:-1])] + [num_head_att[-1]]
|
||||
in_dims = dims[:-1]
|
||||
out_dims = [d // nh for (d, nh) in zip(dims[1:], num_head_att)]
|
||||
self.attn = nn.ModuleList(
|
||||
[GraphAttention(i, o, nh, dropout) for (i, o, nh) in zip(in_dims, out_dims, num_head_att)]
|
||||
)
|
||||
self.bns = nn.ModuleList([nn.BatchNorm1d(dim) for dim in dims[1:-1]])
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.elu = nn.ELU()
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.model(x)
|
||||
hidden = out[:, -1, :]
|
||||
hidden = self.bn1(hidden)
|
||||
attention = GraphAttention.cal_attention(hidden, hidden)
|
||||
output = attention.mm(hidden)
|
||||
output = self.fc(output)
|
||||
output = self.bn2(output)
|
||||
output = self.leaky_relu(output)
|
||||
return self.fc_out(output).squeeze()
|
||||
|
||||
|
||||
class GraphAttention(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, num_heads, dropout=0.5):
|
||||
|
||||
super().__init__()
|
||||
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
Dimension of input node features.
|
||||
output_dim : int
|
||||
Dimension of output node features.
|
||||
num_heads : list of ints
|
||||
Number of attention heads in each hidden layer and output layer. Must be non empty. Note that len(num_heads) = len(hidden_dims)+1.
|
||||
dropout : float
|
||||
Dropout rate. Default: 0.5.
|
||||
"""
|
||||
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.fcs = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
|
||||
self.a = nn.ModuleList([nn.Linear(2 * output_dim, 1) for _ in range(num_heads)])
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.softmax = nn.Softmax(dim=0)
|
||||
self.leakyrelu = nn.LeakyReLU()
|
||||
|
||||
def forward(self, features, nodes, mappings, rows):
|
||||
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
An (n' x input_dim) tensor of input node features.
|
||||
nodes : list of numpy array
|
||||
nodes[i] is an array of the nodes in the ith layer of the
|
||||
computation graph.
|
||||
mappings : list of dictionary
|
||||
mappings[i] is a dictionary mappings node v (labelled 0 to |V|-1)
|
||||
in nodes[i] to its position in nodes[i]. For example,
|
||||
if nodes[i] = [2,5], then mappings[i][2] = 0 and
|
||||
mappings[i][5] = 1.
|
||||
rows : numpy array
|
||||
rows[i] is an array of neighbors of node i.
|
||||
Returns
|
||||
-------
|
||||
out : torch.Tensor
|
||||
An (len(node_layers[-1]) x output_dim) tensor of output node features.
|
||||
"""
|
||||
|
||||
nprime = features.shape[0]
|
||||
rows = [np.array([mappings[v] for v in row], dtype=np.int64) for row in rows]
|
||||
sum_degs = np.hstack(([0], np.cumsum([len(row) for row in rows])))
|
||||
mapped_nodes = [mappings[v] for v in nodes]
|
||||
indices = torch.LongTensor([[v, c] for (v, row) in zip(mapped_nodes, rows) for c in row]).t()
|
||||
|
||||
out = []
|
||||
for k in range(self.num_heads):
|
||||
h = self.fcs[k](features)
|
||||
|
||||
nbr_h = torch.cat(tuple([h[row] for row in rows]), dim=0)
|
||||
self_h = torch.cat(
|
||||
tuple([h[mappings[nodes[i]]].repeat(len(row), 1) for (i, row) in enumerate(rows)]), dim=0
|
||||
)
|
||||
cat_h = torch.cat((self_h, nbr_h), dim=1)
|
||||
|
||||
e = self.leakyrelu(self.a[k](cat_h))
|
||||
|
||||
alpha = [self.softmax(e[lo:hi]) for (lo, hi) in zip(sum_degs, sum_degs[1:])]
|
||||
alpha = torch.cat(tuple(alpha), dim=0)
|
||||
alpha = alpha.squeeze(1)
|
||||
alpha = self.dropout(alpha)
|
||||
|
||||
adj = torch.sparse.FloatTensor(indices, alpha, torch.Size([nprime, nprime]))
|
||||
out.append(torch.sparse.mm(adj, h)[mapped_nodes])
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def cal_attention(x, y):
|
||||
att_x = torch.mean(x, dim=1).reshape(-1, 1)
|
||||
att_y = torch.mean(y, dim=1).reshape(-1, 1)
|
||||
att = att_x.mm(torch.t(att_y))
|
||||
return (
|
||||
torch.mean(
|
||||
x.reshape(x.shape[0], 1, x.shape[1]).repeat(1, y.shape[0], 1)
|
||||
* y.reshape(1, y.shape[0], y.shape[1]).repeat(x.shape[0], 1, 1),
|
||||
dim=2,
|
||||
)
|
||||
- att
|
||||
)
|
||||
@@ -28,14 +28,10 @@ class LSTM(Model):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dim : int
|
||||
input dimension
|
||||
output_dim : int
|
||||
output dimension
|
||||
layers : tuple
|
||||
layer sizes
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
@@ -50,7 +46,7 @@ class LSTM(Model):
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="IC",
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
@@ -112,10 +108,6 @@ class LSTM(Model):
|
||||
)
|
||||
)
|
||||
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.lstm_model = LSTMModel(
|
||||
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
|
||||
)
|
||||
@@ -148,21 +140,16 @@ class LSTM(Model):
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
if self.metric == "IC":
|
||||
return self.cal_ic(pred[mask], label[mask])
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def cal_ic(self, pred, label):
|
||||
return torch.mean(pred * label)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values) * 100
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.lstm_model.train()
|
||||
|
||||
@@ -201,7 +188,6 @@ class LSTM(Model):
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
@@ -251,7 +237,6 @@ class LSTM(Model):
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
# return
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
@@ -90,10 +100,7 @@ class SFM_Model(nn.Module):
|
||||
x_c = torch.matmul(x * B_W[0], self.W_c) + self.b_c
|
||||
x_o = torch.matmul(x * B_W[0], self.W_o) + self.b_o
|
||||
|
||||
i = self.inner_activation(
|
||||
x_i + torch.matmul(h_tm1 * B_U[0], self.U_i)
|
||||
) # not sure whether I am doing in the right unsquuze
|
||||
|
||||
i = self.inner_activation(x_i + torch.matmul(h_tm1 * B_U[0], self.U_i))
|
||||
ste = self.inner_activation(x_ste + torch.matmul(h_tm1 * B_U[0], self.U_ste))
|
||||
fre = self.inner_activation(x_fre + torch.matmul(h_tm1 * B_U[0], self.U_fre))
|
||||
|
||||
@@ -173,10 +180,6 @@ class SFM(Model):
|
||||
output dimension
|
||||
lr : float
|
||||
learning rate
|
||||
lr_decay : float
|
||||
learning rate decay
|
||||
lr_decay_steps : int
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
@@ -193,12 +196,11 @@ class SFM(Model):
|
||||
dropout_U=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
eval_steps=5,
|
||||
loss="mse",
|
||||
lr_decay=0.96,
|
||||
lr_decay_steps=100,
|
||||
optimizer="gd",
|
||||
GPU="0",
|
||||
seed=0,
|
||||
@@ -217,13 +219,12 @@ class SFM(Model):
|
||||
self.dropout_U = dropout_U
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.eval_steps = eval_steps
|
||||
self.lr_decay = lr_decay
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
self.loss = loss
|
||||
self.device = "cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu"
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
@@ -232,16 +233,16 @@ class SFM(Model):
|
||||
"SFM parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\noutput_size : {}"
|
||||
"\nfrequency_dimension : {}"
|
||||
"\ndropout_W: {}"
|
||||
"\ndropout_U: {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\neval_steps : {}"
|
||||
"\nlr_decay : {}"
|
||||
"\nlr_decay_steps : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
@@ -249,16 +250,16 @@ class SFM(Model):
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
output_dim,
|
||||
freq_dim,
|
||||
dropout_W,
|
||||
dropout_U,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
eval_steps,
|
||||
lr_decay,
|
||||
lr_decay_steps,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
@@ -267,10 +268,6 @@ class SFM(Model):
|
||||
)
|
||||
)
|
||||
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.sfm_model = SFM_Model(
|
||||
d_feat=self.d_feat,
|
||||
output_dim=self.output_dim,
|
||||
@@ -287,24 +284,72 @@ class SFM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.train_optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=10,
|
||||
verbose=True,
|
||||
threshold=0.0001,
|
||||
threshold_mode="rel",
|
||||
cooldown=0,
|
||||
min_lr=0.00001,
|
||||
eps=1e-08,
|
||||
)
|
||||
|
||||
self._fitted = False
|
||||
self.sfm_model.to(self.device)
|
||||
|
||||
def fit(self, dataset: DatasetH, evals_result=dict(), verbose=True, save_path=None, **kwargs):
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.sfm_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.sfm_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.sfm_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.sfm_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.sfm_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
@@ -312,10 +357,10 @@ class SFM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = create_save_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
@@ -323,90 +368,51 @@ class SFM(Model):
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
|
||||
# prepare training data
|
||||
x_train_values = torch.from_numpy(x_train.values).float()
|
||||
y_train_values = torch.from_numpy(np.squeeze(y_train.values)).float()
|
||||
train_num = y_train_values.shape[0]
|
||||
|
||||
# prepare validation data
|
||||
x_val_auto = torch.from_numpy(x_valid.values).float()
|
||||
y_val_auto = torch.from_numpy(np.squeeze(y_valid.values)).float()
|
||||
|
||||
x_val_auto = x_val_auto.to(self.device)
|
||||
y_val_auto = y_val_auto.to(self.device)
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
if stop_steps >= self.early_stop:
|
||||
if verbose:
|
||||
self.logger.info("\tearly stop")
|
||||
break
|
||||
loss = AverageMeter()
|
||||
self.sfm_model.train()
|
||||
self.train_optimizer.zero_grad()
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
choice = np.random.choice(train_num, self.batch_size)
|
||||
x_batch_auto = x_train_values[choice]
|
||||
y_batch_auto = y_train_values[choice]
|
||||
|
||||
x_batch_auto = x_batch_auto.to(self.device)
|
||||
y_batch_auto = y_batch_auto.to(self.device)
|
||||
|
||||
# forward
|
||||
preds = self.sfm_model(x_batch_auto)
|
||||
cur_loss = self.get_loss(preds, y_batch_auto, self.loss_type)
|
||||
cur_loss.backward()
|
||||
self.train_optimizer.step()
|
||||
loss.update(cur_loss.item())
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
# print(loss.val)
|
||||
if step and step % self.eval_steps == 0:
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.sfm_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
with torch.no_grad():
|
||||
self.sfm_model.eval()
|
||||
loss_val = AverageMeter()
|
||||
|
||||
# forward
|
||||
preds = self.sfm_model(x_val_auto)
|
||||
cur_loss_val = self.get_loss(preds, y_val_auto, self.loss_type)
|
||||
loss_val.update(cur_loss_val.item())
|
||||
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
|
||||
)
|
||||
evals_result["train"].append(train_loss)
|
||||
evals_result["valid"].append(loss_val.val)
|
||||
if loss_val.val < best_loss:
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
"\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
|
||||
best_loss, loss_val.val
|
||||
)
|
||||
)
|
||||
best_loss = loss_val.val
|
||||
stop_steps = 0
|
||||
torch.save(self.sfm_model.state_dict(), save_path)
|
||||
train_loss = 0
|
||||
# update learning rate
|
||||
self.scheduler.step(cur_loss_val)
|
||||
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
if self.device != "cpu":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_loss(self, pred, target, loss_type):
|
||||
if loss_type == "mse":
|
||||
sqr_loss = (pred - target) ** 2
|
||||
loss = sqr_loss.mean()
|
||||
return loss
|
||||
elif loss_type == "binary":
|
||||
loss = nn.BCELoss()
|
||||
return loss(pred, target)
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss": # use loss
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
@@ -414,34 +420,28 @@ class SFM(Model):
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
x_test = torch.from_numpy(x_test.values).float()
|
||||
|
||||
x_test = x_test.to(self.device)
|
||||
self.sfm_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
with torch.no_grad():
|
||||
if self.device != "cpu":
|
||||
preds = self.sfm_model(x_test).detach().cpu().numpy()
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
preds = self.sfm_model(x_test).detach().numpy()
|
||||
return pd.Series(preds, index=index)
|
||||
end = begin + self.batch_size
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
with save_multiple_parts_file(filename) as model_dir:
|
||||
model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])
|
||||
# Save model
|
||||
torch.save(self.sfm_model.state_dict(), model_path)
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float()
|
||||
|
||||
def load(self, buffer, **kwargs):
|
||||
with unpack_archive_with_buffer(buffer) as model_dir:
|
||||
# Get model name
|
||||
_model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[
|
||||
0
|
||||
]
|
||||
_model_path = os.path.join(model_dir, _model_name)
|
||||
# Load model
|
||||
self.sfm_model.load_state_dict(torch.load(_model_path))
|
||||
self._fitted = True
|
||||
if self.device != "cpu":
|
||||
x_batch = x_batch.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.sfm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -13,10 +22,8 @@ from ...data.dataset.handler import DataHandlerLP
|
||||
class XGBModel(Model):
|
||||
"""XGBModel Model"""
|
||||
|
||||
def __init__(self, obj="mse", **kwargs):
|
||||
if obj not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self._params = {"obj": obj}
|
||||
def __init__(self, **kwargs):
|
||||
self._params = {}
|
||||
self._params.update(kwargs)
|
||||
self.model = None
|
||||
|
||||
|
||||
@@ -252,7 +252,7 @@ def model_performance_graph(
|
||||
"""Model performance
|
||||
|
||||
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score,
|
||||
label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1")
|
||||
label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1").
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
@@ -266,13 +266,13 @@ def model_performance_graph(
|
||||
|
||||
|
||||
:param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing.
|
||||
:param N: group number, default 5
|
||||
:param reverse: if `True`, `pred['score'] *= -1`
|
||||
:param rank: if **True**, calculate rank ic
|
||||
:param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover']
|
||||
:param show_notebook: whether to display graphics in notebook, the default is `True`
|
||||
:param show_nature_day: whether to display the abscissa of non-trading day
|
||||
:return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list
|
||||
:param N: group number, default 5.
|
||||
:param reverse: if `True`, `pred['score'] *= -1`.
|
||||
:param rank: if **True**, calculate rank ic.
|
||||
:param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'].
|
||||
:param show_notebook: whether to display graphics in notebook, the default is `True`.
|
||||
:param show_nature_day: whether to display the abscissa of non-trading day.
|
||||
:return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list.
|
||||
"""
|
||||
figure_list = []
|
||||
for graph_name in graph_names:
|
||||
|
||||
@@ -218,10 +218,10 @@ def cumulative_return_graph(
|
||||
|
||||
|
||||
Graph desc:
|
||||
- Axis X: Trading day
|
||||
- Axis X: Trading day.
|
||||
- Axis Y:
|
||||
- Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`
|
||||
- Below axis Y: Daily weight sum
|
||||
- Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`.
|
||||
- Below axis Y: Daily weight sum.
|
||||
- In the **sell** graph, `y < 0` stands for profit; in other cases, `y > 0` stands for profit.
|
||||
- In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + sell_weight`.
|
||||
- In each graph, the **red line** in the histogram on the right represents the average.
|
||||
|
||||
@@ -97,9 +97,9 @@ def rank_label_graph(
|
||||
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
|
||||
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
|
||||
:param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**.
|
||||
**The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`
|
||||
**The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
@@ -115,7 +115,7 @@ def rank_label_graph(
|
||||
|
||||
:param start_date: start date
|
||||
:param end_date: end_date
|
||||
:param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures
|
||||
:param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures.
|
||||
:return:
|
||||
"""
|
||||
position = copy.deepcopy(position)
|
||||
|
||||
@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
qcr.report_graph(report_normal_df)
|
||||
|
||||
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**
|
||||
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
@@ -200,8 +200,8 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
|
||||
|
||||
:param show_notebook: whether to display graphics in notebook, the default is **True**
|
||||
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
|
||||
:param show_notebook: whether to display graphics in notebook, the default is **True**.
|
||||
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
|
||||
"""
|
||||
report_df = report_df.copy()
|
||||
fig_list = _report_figure(report_df)
|
||||
|
||||
@@ -218,7 +218,7 @@ def risk_analysis_graph(
|
||||
max_drawdown -0.088263
|
||||
|
||||
|
||||
:param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**
|
||||
:param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
@@ -232,7 +232,7 @@ def risk_analysis_graph(
|
||||
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
|
||||
|
||||
|
||||
:param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**
|
||||
:param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
@@ -246,7 +246,7 @@ def risk_analysis_graph(
|
||||
2017-01-10 0.000824 -0.001944 -0.001120
|
||||
|
||||
|
||||
:param show_notebook: Whether to display graphics in a notebook, default **True**
|
||||
:param show_notebook: Whether to display graphics in a notebook, default **True**.
|
||||
If True, show graph in notebook
|
||||
If False, return graph figure
|
||||
:return:
|
||||
|
||||
@@ -36,7 +36,7 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
|
||||
analysis_position.score_ic_graph(pred_label)
|
||||
|
||||
|
||||
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**
|
||||
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
@@ -49,8 +49,8 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
|
||||
2017-12-15 -0.102778 -0.102778
|
||||
|
||||
|
||||
:param show_notebook: whether to display graphics in notebook, the default is **True**
|
||||
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
|
||||
:param show_notebook: whether to display graphics in notebook, the default is **True**.
|
||||
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
|
||||
"""
|
||||
_ic_df = _get_score_ic(pred_label)
|
||||
# FIXME: support HIGH-FREQ
|
||||
|
||||
@@ -31,16 +31,16 @@ class BaseStrategy:
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Seires
|
||||
stock_id , score
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current state of position
|
||||
DO NOT directly change the state of current
|
||||
current state of position.
|
||||
DO NOT directly change the state of current.
|
||||
trade_exchange : Exchange()
|
||||
trade exchange
|
||||
trade exchange.
|
||||
pred_date : pd.Timestamp
|
||||
predict date
|
||||
predict date.
|
||||
trade_date : pd.Timestamp
|
||||
trade date
|
||||
trade date.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -49,11 +49,11 @@ class BaseStrategy:
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Series
|
||||
stock_id , score
|
||||
stock_id , score.
|
||||
pred_date : pd.Timestamp
|
||||
oredict date
|
||||
oredict date.
|
||||
trade_date : pd.Timestamp
|
||||
trade date
|
||||
trade date.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -67,7 +67,7 @@ class BaseStrategy:
|
||||
"""
|
||||
This method only be used in 'online' module, it will generate the *args to initial the strategy.
|
||||
:param
|
||||
mode : model used in 'online' module
|
||||
mode : model used in 'online' module.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@@ -82,7 +82,7 @@ class StrategyWrapper:
|
||||
def __init__(self, inner_strategy):
|
||||
"""__init__
|
||||
|
||||
:param inner_strategy: set the inner strategy
|
||||
:param inner_strategy: set the inner strategy.
|
||||
"""
|
||||
self.inner_strategy = inner_strategy
|
||||
|
||||
@@ -99,9 +99,9 @@ class AdjustTimer:
|
||||
Responsible for timing of position adjusting
|
||||
|
||||
This is designed as multiple inheritance mechanism due to:
|
||||
- the is_adjust may need access to the internel state of a strategy
|
||||
- the is_adjust may need access to the internel state of a strategy.
|
||||
|
||||
- it can be reguard as a enhancement to the existing strategy
|
||||
- it can be reguard as a enhancement to the existing strategy.
|
||||
"""
|
||||
|
||||
# adjust position in each trade date
|
||||
@@ -146,12 +146,12 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
|
||||
Parameters
|
||||
-----------
|
||||
score : pd.Series
|
||||
pred score for this trade date, index is stock_id, contain 'score' column
|
||||
pred score for this trade date, index is stock_id, contain 'score' column.
|
||||
current : Position()
|
||||
current position
|
||||
current position.
|
||||
trade_exchange : Exchange()
|
||||
trade_date : pd.Timestamp
|
||||
trade date
|
||||
trade date.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -160,13 +160,13 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Seires
|
||||
stock_id , score
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current of account
|
||||
current of account.
|
||||
trade_exchange : Exchange()
|
||||
exchange
|
||||
exchange.
|
||||
trade_date : pd.Timestamp
|
||||
date
|
||||
date.
|
||||
"""
|
||||
# judge if to adjust
|
||||
if not self.is_adjust(trade_date):
|
||||
@@ -206,26 +206,26 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
Parameters
|
||||
-----------
|
||||
topk : int
|
||||
The number of stocks in the portfolio
|
||||
the number of stocks in the portfolio.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date
|
||||
number of stocks to be replaced in each trading date.
|
||||
method_sell : str
|
||||
dropout method_sell, random/bottom
|
||||
dropout method_sell, random/bottom.
|
||||
method_buy : str
|
||||
dropout method_buy, random/top
|
||||
dropout method_buy, random/top.
|
||||
risk_degree : float
|
||||
position percentage of total value
|
||||
position percentage of total value.
|
||||
thresh : int
|
||||
minimun holding days since last buy singal of the stock
|
||||
minimun holding days since last buy singal of the stock.
|
||||
hold_thresh : int
|
||||
minimum holding days
|
||||
before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh
|
||||
before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh.
|
||||
only_tradable : bool
|
||||
will the strategy only consider the tradable stock when buying and selling.
|
||||
if only_tradable:
|
||||
strategy will make buy sell decision without checking the tradable state of the stock
|
||||
strategy will make buy sell decision without checking the tradable state of the stock.
|
||||
else:
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them
|
||||
strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__()
|
||||
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
|
||||
@@ -245,7 +245,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
def get_risk_degree(self, date):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing
|
||||
Dynamically risk_degree will result in Market timing.
|
||||
"""
|
||||
# It will use 95% amoutn of your total value by default
|
||||
return self.risk_degree
|
||||
@@ -257,15 +257,15 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Series
|
||||
stock_id , score
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current of account
|
||||
current of account.
|
||||
trade_exchange : Exchange()
|
||||
exchange
|
||||
exchange.
|
||||
pred_date : pd.Timestamp
|
||||
predict date
|
||||
predict date.
|
||||
trade_date : pd.Timestamp
|
||||
trade date
|
||||
trade date.
|
||||
"""
|
||||
if not self.is_adjust(trade_date):
|
||||
return []
|
||||
|
||||
@@ -129,13 +129,13 @@ class Expression(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
instrument : str
|
||||
instrument code
|
||||
instrument code.
|
||||
start_index : str
|
||||
feature start index [in calendar]
|
||||
feature start index [in calendar].
|
||||
end_index : str
|
||||
feature end index [in calendar]
|
||||
feature end index [in calendar].
|
||||
freq : str
|
||||
feature frequency
|
||||
feature frequency.
|
||||
|
||||
Returns
|
||||
----------
|
||||
|
||||
@@ -76,8 +76,8 @@ class MemCache(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mem_cache_size_limit: cache max size
|
||||
limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof)
|
||||
mem_cache_size_limit: cache max size.
|
||||
limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof).
|
||||
"""
|
||||
if limit_type not in ["length", "sizeof"]:
|
||||
raise ValueError(f"limit_type must be length or sizeof, your limit_type is {limit_type}")
|
||||
@@ -118,9 +118,9 @@ class MemCacheExpire:
|
||||
def set_cache(mem_cache, key, value):
|
||||
"""set cache
|
||||
|
||||
:param mem_cache: MemCache attribute('c'/'i'/'f')
|
||||
:param key: cache key
|
||||
:param value: cache value
|
||||
:param mem_cache: MemCache attribute('c'/'i'/'f').
|
||||
:param key: cache key.
|
||||
:param value: cache value.
|
||||
"""
|
||||
mem_cache[key] = value, time.time()
|
||||
|
||||
@@ -128,9 +128,9 @@ class MemCacheExpire:
|
||||
def get_cache(mem_cache, key):
|
||||
"""get mem cache
|
||||
|
||||
:param mem_cache: MemCache attribute('c'/'i'/'f')
|
||||
:param key: cache key
|
||||
:return: cache value; if cache not exist, return None
|
||||
:param mem_cache: MemCache attribute('c'/'i'/'f').
|
||||
:param key: cache key.
|
||||
:return: cache value; if cache not exist, return None.
|
||||
"""
|
||||
value = None
|
||||
expire = False
|
||||
@@ -275,12 +275,12 @@ class ExpressionCache(BaseProviderCache):
|
||||
Parameters
|
||||
----------
|
||||
cache_uri : str
|
||||
the complete uri of expression cache file (include dir path)
|
||||
the complete uri of expression cache file (include dir path).
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
0(successful update)/ 1(no need to update)/ 2(update failure)
|
||||
0(successful update)/ 1(no need to update)/ 2(update failure).
|
||||
"""
|
||||
raise NotImplementedError("Implement this method if you want to make expression cache up to date")
|
||||
|
||||
@@ -348,7 +348,7 @@ class DatasetCache(BaseProviderCache):
|
||||
Parameters
|
||||
----------
|
||||
cache_uri : str
|
||||
the complete uri of dataset cache file (include dir path)
|
||||
the complete uri of dataset cache file (include dir path).
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -361,9 +361,9 @@ class DatasetCache(BaseProviderCache):
|
||||
def cache_to_origin_data(data, fields):
|
||||
"""cache data to origin data
|
||||
|
||||
:param data: pd.DataFrame, cache data
|
||||
:param fields: feature fields
|
||||
:return: pd.DataFrame
|
||||
:param data: pd.DataFrame, cache data.
|
||||
:param fields: feature fields.
|
||||
:return: pd.DataFrame.
|
||||
"""
|
||||
not_space_fields = remove_fields_space(fields)
|
||||
data = data.loc[:, not_space_fields]
|
||||
@@ -583,7 +583,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
:param cache_path:
|
||||
:param start_time:
|
||||
:param end_time:
|
||||
:param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent
|
||||
:param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent.
|
||||
:return:
|
||||
"""
|
||||
|
||||
@@ -771,12 +771,12 @@ class DiskDatasetCache(DatasetCache):
|
||||
|
||||
- This is a hdf file sorted by datetime
|
||||
|
||||
:param cache_path: The path to store the cache
|
||||
:param instruments: The instruments to store the cache
|
||||
:param fields: The fields to store the cache
|
||||
:param freq: The freq to store the cache
|
||||
:param cache_path: The path to store the cache.
|
||||
:param instruments: The instruments to store the cache.
|
||||
:param fields: The fields to store the cache.
|
||||
:param freq: The freq to store the cache.
|
||||
|
||||
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function
|
||||
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
|
||||
"""
|
||||
# get calendar
|
||||
from .data import Cal
|
||||
|
||||
@@ -51,13 +51,13 @@ class Client(object):
|
||||
Parameters
|
||||
----------
|
||||
request_type : str
|
||||
type of proposed request, 'calendar'/'instrument'/'feature'
|
||||
type of proposed request, 'calendar'/'instrument'/'feature'.
|
||||
request_content : dict
|
||||
records the information of the request
|
||||
records the information of the request.
|
||||
msg_proc_func : func
|
||||
the function to process the message when receiving response, should have arg `*args`
|
||||
the function to process the message when receiving response, should have arg `*args`.
|
||||
msg_queue: Queue
|
||||
The queue to pass the messsage after callback
|
||||
The queue to pass the messsage after callback.
|
||||
"""
|
||||
head_info = {"version": qlib.__version__}
|
||||
|
||||
|
||||
@@ -41,13 +41,13 @@ class CalendarProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency, available: year/quarter/month/week/day
|
||||
time frequency, available: year/quarter/month/week/day.
|
||||
future : bool
|
||||
whether including future trading day
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -62,24 +62,24 @@ class CalendarProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency, available: year/quarter/month/week/day
|
||||
time frequency, available: year/quarter/month/week/day.
|
||||
future : bool
|
||||
whether including future trading day
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.Timestamp
|
||||
the real start time
|
||||
the real start time.
|
||||
pd.Timestamp
|
||||
the real end time
|
||||
the real end time.
|
||||
int
|
||||
the index of start time
|
||||
the index of start time.
|
||||
int
|
||||
the index of end time
|
||||
the index of end time.
|
||||
"""
|
||||
start_time = pd.Timestamp(start_time)
|
||||
end_time = pd.Timestamp(end_time)
|
||||
@@ -103,16 +103,16 @@ class CalendarProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file
|
||||
frequency of read calendar file.
|
||||
future : bool
|
||||
whether including future trading day
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
list of timestamps
|
||||
list of timestamps.
|
||||
dict
|
||||
dict composed by timestamp as key and index as value for fast search
|
||||
dict composed by timestamp as key and index as value for fast search.
|
||||
"""
|
||||
flag = f"{freq}_future_{future}"
|
||||
if flag in H["c"]:
|
||||
@@ -141,14 +141,14 @@ class InstrumentProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
market : str
|
||||
market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500
|
||||
market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.
|
||||
filter_pipe : list
|
||||
the list of dynamic filters
|
||||
the list of dynamic filters.
|
||||
|
||||
Returns
|
||||
----------
|
||||
dict
|
||||
dict of stockpool config
|
||||
dict of stockpool config.
|
||||
{`market`=>base market name, `filter_pipe`=>list of filters}
|
||||
|
||||
example :
|
||||
@@ -182,13 +182,13 @@ class InstrumentProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
instruments : dict
|
||||
stockpool config
|
||||
stockpool config.
|
||||
start_time : str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
as_list : bool
|
||||
return instruments as list or dict
|
||||
return instruments as list or dict.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -243,15 +243,15 @@ class FeatureProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
instrument : str
|
||||
a certain instrument
|
||||
a certain instrument.
|
||||
field : str
|
||||
a certain field of feature
|
||||
a certain field of feature.
|
||||
start_time : str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency, available: year/quarter/month/week/day
|
||||
time frequency, available: year/quarter/month/week/day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -294,15 +294,15 @@ class ExpressionProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
instrument : str
|
||||
a certain instrument
|
||||
a certain instrument.
|
||||
field : str
|
||||
a certain field of feature
|
||||
a certain field of feature.
|
||||
start_time : str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency, available: year/quarter/month/week/day
|
||||
time frequency, available: year/quarter/month/week/day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -325,20 +325,20 @@ class DatasetProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
instruments : list or dict
|
||||
list/dict of instruments or dict of stockpool config
|
||||
list/dict of instruments or dict of stockpool config.
|
||||
fields : list
|
||||
list of feature instances
|
||||
list of feature instances.
|
||||
start_time : str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency
|
||||
time frequency.
|
||||
|
||||
Returns
|
||||
----------
|
||||
pd.DataFrame
|
||||
a pandas dataframe with <instrument, datetime> index
|
||||
a pandas dataframe with <instrument, datetime> index.
|
||||
"""
|
||||
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")
|
||||
|
||||
@@ -357,17 +357,17 @@ class DatasetProvider(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
instruments : list or dict
|
||||
list/dict of instruments or dict of stockpool config
|
||||
list/dict of instruments or dict of stockpool config.
|
||||
fields : list
|
||||
list of feature instances
|
||||
list of feature instances.
|
||||
start_time : str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency
|
||||
time frequency.
|
||||
disk_cache : int
|
||||
whether to skip(0)/use(1)/replace(2) disk_cache
|
||||
whether to skip(0)/use(1)/replace(2) disk_cache.
|
||||
|
||||
"""
|
||||
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
|
||||
@@ -526,7 +526,7 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file
|
||||
frequency of read calendar file.
|
||||
|
||||
Returns
|
||||
----------
|
||||
|
||||
@@ -17,8 +17,8 @@ class Dataset(Serializable):
|
||||
init is designed to finish following steps:
|
||||
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing
|
||||
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
- initialize the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
@@ -29,17 +29,17 @@ class Dataset(Serializable):
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
"""
|
||||
setup the data
|
||||
Setup the data.
|
||||
|
||||
We split the setup_data function for following situation:
|
||||
|
||||
- User have a Dataset object with learned status on disk
|
||||
- User have a Dataset object with learned status on disk.
|
||||
|
||||
- User load the Dataset object from the disk(Note the init function is skiped)
|
||||
- User load the Dataset object from the disk(Note the init function is skiped).
|
||||
|
||||
- User call `setup_data` to load new data
|
||||
- User call `setup_data` to load new data.
|
||||
|
||||
- User prepare data for model based on previous status
|
||||
- User prepare data for model based on previous status.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -66,9 +66,10 @@ class DatasetH(Dataset):
|
||||
|
||||
User should try to put the data preprocessing functions into handler.
|
||||
Only following data processing functions should be placed in Dataset:
|
||||
|
||||
- The processing is related to specific model.
|
||||
|
||||
- The processing is related to data split
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[dict, DataHandler], segments: list):
|
||||
@@ -76,15 +77,15 @@ class DatasetH(Dataset):
|
||||
Parameters
|
||||
----------
|
||||
handler : Union[dict, DataHandler]
|
||||
handler will be passed into setup_data
|
||||
handler will be passed into setup_data.
|
||||
segments : list
|
||||
handler will be passed into setup_data
|
||||
handler will be passed into setup_data.
|
||||
"""
|
||||
super().__init__(handler, segments)
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: list):
|
||||
"""
|
||||
setup the underlying data
|
||||
Setup the underlying data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -94,12 +95,13 @@ class DatasetH(Dataset):
|
||||
- insntance of `DataHandler`
|
||||
|
||||
- config of `DataHandler`. Please refer to `DataHandler`
|
||||
|
||||
segments : list
|
||||
Describe the options to segment the data.
|
||||
Here are some examples:
|
||||
|
||||
.. code-block::
|
||||
|
||||
|
||||
1) 'segments': {
|
||||
'train': ("2008-01-01", "2014-12-31"),
|
||||
'valid': ("2017-01-01", "2020-08-01",),
|
||||
@@ -121,7 +123,7 @@ class DatasetH(Dataset):
|
||||
**kwargs,
|
||||
) -> Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
"""
|
||||
prepare the data for learning and inference
|
||||
Prepare the data for learning and inference.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -132,11 +134,12 @@ class DatasetH(Dataset):
|
||||
- 'train'
|
||||
|
||||
- ['train', 'valid']
|
||||
|
||||
col_set : str
|
||||
The col_set will be passed to self._handler when fetching data
|
||||
data_key: str
|
||||
The col_set will be passed to self._handler when fetching data.
|
||||
data_key : str
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicate fetching data for **inference**
|
||||
Default is DK_I, which indicate fetching data for **inference**.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -29,7 +29,7 @@ class DataHandler(Serializable):
|
||||
"""
|
||||
The steps to using a handler
|
||||
1. initialized data handler (call by `init`).
|
||||
2. use the data
|
||||
2. use the data.
|
||||
|
||||
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
@@ -65,17 +65,17 @@ class DataHandler(Serializable):
|
||||
Parameters
|
||||
----------
|
||||
instruments :
|
||||
The stock list to retrive
|
||||
The stock list to retrive.
|
||||
start_time :
|
||||
start_time of the original data
|
||||
start_time of the original data.
|
||||
end_time :
|
||||
end_time of the original data
|
||||
end_time of the original data.
|
||||
data_loader : Tuple[dict, str, DataLoader]
|
||||
data loader to load the data
|
||||
data loader to load the data.
|
||||
init_data :
|
||||
intialize the original data in the constructor
|
||||
intialize the original data in the constructor.
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible
|
||||
Return the original data instead of copy if possible.
|
||||
"""
|
||||
# Set logger
|
||||
self.logger = get_module_logger("DataHandler")
|
||||
@@ -219,9 +219,9 @@ class DataHandler(Serializable):
|
||||
get a iterator of sliced data with given periods
|
||||
|
||||
Args:
|
||||
periods (int): number of periods
|
||||
min_periods (int): minimum periods for sliced dataframe
|
||||
kwargs (dict): will be passed to `self.fetch`
|
||||
periods (int): number of periods.
|
||||
min_periods (int): minimum periods for sliced dataframe.
|
||||
kwargs (dict): will be passed to `self.fetch`.
|
||||
"""
|
||||
trading_dates = self._data.index.unique(level="datetime")
|
||||
if min_periods is None:
|
||||
@@ -243,10 +243,10 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
# process type
|
||||
PTYPE_I = "independent"
|
||||
# - self._infer will processed by infer_processors
|
||||
# - self._infer will be processed by infer_processors
|
||||
# - self._learn will be processed by learn_processors
|
||||
PTYPE_A = "append"
|
||||
# - self._infer will processed by infer_processors
|
||||
# - self._infer will be processed by infer_processors
|
||||
# - self._learn will be processed by infer_processors + learn_processors
|
||||
# - (e.g. self._infer processed by learn_processors )
|
||||
|
||||
@@ -265,30 +265,40 @@ class DataHandlerLP(DataHandler):
|
||||
Parameters
|
||||
----------
|
||||
infer_processors : list
|
||||
list of <description info> of processors to generate data for inference
|
||||
example of <description info>:
|
||||
1) classname & kwargs:
|
||||
{
|
||||
"class": "MinMaxNorm",
|
||||
"kwargs": {
|
||||
"fit_start_time": "20080101",
|
||||
"fit_end_time": "20121231"
|
||||
- list of <description info> of processors to generate data for inference
|
||||
|
||||
- example of <description info>:
|
||||
|
||||
.. code-block::
|
||||
|
||||
1) classname & kwargs:
|
||||
{
|
||||
"class": "MinMaxNorm",
|
||||
"kwargs": {
|
||||
"fit_start_time": "20080101",
|
||||
"fit_end_time": "20121231"
|
||||
}
|
||||
}
|
||||
}
|
||||
2) Only classname:
|
||||
"DropnaFeature"
|
||||
3) object instance of Processor
|
||||
2) Only classname:
|
||||
"DropnaFeature"
|
||||
3) object instance of Processor
|
||||
|
||||
learn_processors : list
|
||||
similar to infer_processors, but for generating data for learning models
|
||||
|
||||
process_type: str
|
||||
PTYPE_I = 'independent'
|
||||
|
||||
- self._infer will processed by infer_processors
|
||||
|
||||
- self._learn will be processed by learn_processors
|
||||
|
||||
PTYPE_A = 'append'
|
||||
|
||||
- self._infer will processed by infer_processors
|
||||
|
||||
- self._learn will be processed by infer_processors + learn_processors
|
||||
|
||||
- (e.g. self._infer processed by learn_processors )
|
||||
"""
|
||||
|
||||
@@ -377,7 +387,7 @@ class DataHandlerLP(DataHandler):
|
||||
Parameters
|
||||
----------
|
||||
init_type : str
|
||||
The type `IT_*` listed above
|
||||
The type `IT_*` listed above.
|
||||
enable_cache : bool
|
||||
default value is false:
|
||||
|
||||
@@ -419,13 +429,13 @@ class DataHandlerLP(DataHandler):
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
describe how to select data by index
|
||||
describe how to select data by index.
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
which index level to select the data.
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
data_key: str
|
||||
The data to fetch: DK_*
|
||||
select a set of meaningful columns.(e.g. features, columns).
|
||||
data_key : str
|
||||
the data to fetch: DK_*.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -443,9 +453,9 @@ class DataHandlerLP(DataHandler):
|
||||
Parameters
|
||||
----------
|
||||
col_set : str
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
data_key: str
|
||||
The data to fetch: DK_*
|
||||
select a set of meaningful columns.(e.g. features, columns).
|
||||
data_key : str
|
||||
the data to fetch: DK_*.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -21,27 +21,11 @@ class DataLoader(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
"""
|
||||
load the data as pd.DataFrame
|
||||
load the data as pd.DataFrame.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
instruments : [TODO:type]
|
||||
[TODO:description]
|
||||
start_time : [TODO:type]
|
||||
[TODO:description]
|
||||
end_time : [TODO:type]
|
||||
[TODO:description]
|
||||
Example of the data (The multi-index of the columns is optional.):
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
|
||||
Example of the data (The multi-index of the columns is optional.):
|
||||
|
||||
.. code-block::
|
||||
.. code-block:: python
|
||||
|
||||
feature label
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
@@ -49,6 +33,21 @@ class DataLoader(abc.ABC):
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
instruments : str or dict
|
||||
it can either be the market name or the config file of instruments generated by InstrumentProvider.
|
||||
start_time : str
|
||||
start of the time range.
|
||||
end_time : str
|
||||
end of the time range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -67,7 +66,7 @@ class DLWParser(DataLoader):
|
||||
config : Tuple[list, tuple, dict]
|
||||
Config will be used to describe the fields and column names
|
||||
|
||||
.. code-block:: YAML
|
||||
.. code-block::
|
||||
|
||||
<config> := {
|
||||
"group_name1": <fields_info1>
|
||||
@@ -102,16 +101,16 @@ class DLWParser(DataLoader):
|
||||
Parameters
|
||||
----------
|
||||
instruments :
|
||||
the instruments
|
||||
the instruments.
|
||||
exprs : list
|
||||
The expressions to describe the content of the data
|
||||
the expressions to describe the content of the data.
|
||||
names : list
|
||||
The name of the data
|
||||
the name of the data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
the queried dataframe
|
||||
the queried dataframe.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ def get_group_columns(df: pd.DataFrame, group: str):
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
with multi of columns
|
||||
with multi of columns.
|
||||
group : str
|
||||
the name of the feature group, i.e. the first level value of the group index.
|
||||
"""
|
||||
@@ -56,7 +56,7 @@ class Processor(Serializable):
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The raw_df of handler or result from previous processor
|
||||
The raw_df of handler or result from previous processor.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -68,7 +68,7 @@ class Processor(Serializable):
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
if it is usable for infenrece
|
||||
if it is usable for infenrece.
|
||||
"""
|
||||
return True
|
||||
|
||||
@@ -176,7 +176,9 @@ class MinMaxNorm(Processor):
|
||||
return df
|
||||
|
||||
|
||||
class ZscoreNorm(Processor):
|
||||
class ZScoreNorm(Processor):
|
||||
"""ZScore Normalization"""
|
||||
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
@@ -203,6 +205,42 @@ class ZscoreNorm(Processor):
|
||||
return df
|
||||
|
||||
|
||||
class RobustZScoreNorm(Processor):
|
||||
"""Robust ZScore Normalization
|
||||
|
||||
Use robust statistics for Z-Score normalization:
|
||||
mean(x) = median(x)
|
||||
std(x) = MAD(x) * 1.4826
|
||||
|
||||
Reference:
|
||||
https://en.wikipedia.org/wiki/Median_absolute_deviation.
|
||||
"""
|
||||
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True):
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
self.clip_outlier = clip_outlier
|
||||
|
||||
def fit(self, df):
|
||||
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
self.cols = get_group_columns(df, self.fields_group)
|
||||
X = df[self.cols].values
|
||||
self.mean_train = np.nanmedian(X, axis=0)
|
||||
self.std_train = np.nanmedian(np.abs(X - self.mean_train), axis=0)
|
||||
self.std_train += EPS
|
||||
self.std_train *= 1.4826
|
||||
|
||||
def __call__(self, df):
|
||||
X = df[self.cols]
|
||||
X -= self.mean_train
|
||||
X /= self.std_train
|
||||
df[self.cols] = X
|
||||
if self.clip_outlier:
|
||||
df.clip(-3, 3, inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
class CSZScoreNorm(Processor):
|
||||
"""Cross Sectional ZScore Normalization"""
|
||||
|
||||
|
||||
@@ -51,6 +51,9 @@ def fetch_df_by_index(
|
||||
-------
|
||||
Data of the given index.
|
||||
"""
|
||||
# level = None -> use selector directly
|
||||
if level == None:
|
||||
return df.loc(axis=0)[selector]
|
||||
# Try to get the right index
|
||||
idx_slc = (selector, slice(None, None))
|
||||
if get_level_index(df, level) == 1:
|
||||
|
||||
@@ -32,7 +32,7 @@ class BaseDFilter(abc.ABC):
|
||||
Parameters
|
||||
----------
|
||||
config : dict
|
||||
dict of config parameters
|
||||
dict of config parameters.
|
||||
"""
|
||||
raise NotImplementedError("Subclass of BaseDFilter must reimplement `from_config` method")
|
||||
|
||||
@@ -43,7 +43,7 @@ class BaseDFilter(abc.ABC):
|
||||
Returns
|
||||
----------
|
||||
dict
|
||||
return the dict of config parameters
|
||||
return the dict of config parameters.
|
||||
"""
|
||||
raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method")
|
||||
|
||||
@@ -69,9 +69,9 @@ class SeriesDFilter(BaseDFilter):
|
||||
Parameters
|
||||
----------
|
||||
fstart_time: str
|
||||
the time for the filter rule to start filter the instruments
|
||||
the time for the filter rule to start filter the instruments.
|
||||
fend_time: str
|
||||
the time for the filter rule to stop filter the instruments
|
||||
the time for the filter rule to stop filter the instruments.
|
||||
"""
|
||||
super(SeriesDFilter, self).__init__()
|
||||
self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None
|
||||
@@ -83,12 +83,12 @@ class SeriesDFilter(BaseDFilter):
|
||||
Parameters
|
||||
----------
|
||||
instruments: dict
|
||||
the dict of instruments in the form {instrument_name => list of timestamp tuple}
|
||||
the dict of instruments in the form {instrument_name => list of timestamp tuple}.
|
||||
|
||||
Returns
|
||||
----------
|
||||
pd.Timestamp, pd.Timestamp
|
||||
the lower time bound and upper time bound of all the instruments
|
||||
the lower time bound and upper time bound of all the instruments.
|
||||
"""
|
||||
trange = Cal.calendar(freq=self.filter_freq)
|
||||
ubound, lbound = trange[0], trange[-1]
|
||||
@@ -105,14 +105,14 @@ class SeriesDFilter(BaseDFilter):
|
||||
Parameters
|
||||
----------
|
||||
time_range : D.calendar
|
||||
the time range of the instruments
|
||||
the time range of the instruments.
|
||||
target_timestamp : list
|
||||
the list of tuple (timestamp, timestamp)
|
||||
the list of tuple (timestamp, timestamp).
|
||||
|
||||
Returns
|
||||
----------
|
||||
pd.Series
|
||||
the series of bool value for an instrument
|
||||
the series of bool value for an instrument.
|
||||
"""
|
||||
# Construct a whole dict of {date => bool}
|
||||
timestamp_series = {timestamp: False for timestamp in time_range}
|
||||
@@ -124,19 +124,19 @@ class SeriesDFilter(BaseDFilter):
|
||||
return timestamp_series
|
||||
|
||||
def _filterSeries(self, timestamp_series, filter_series):
|
||||
"""Filter the timestamp series with filter series by using element-wise AND operation of the two series
|
||||
"""Filter the timestamp series with filter series by using element-wise AND operation of the two series.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timestamp_series : pd.Series
|
||||
the series of bool value indicating existing time
|
||||
the series of bool value indicating existing time.
|
||||
filter_series : pd.Series
|
||||
the series of bool value indicating filter feature
|
||||
the series of bool value indicating filter feature.
|
||||
|
||||
Returns
|
||||
----------
|
||||
pd.Series
|
||||
the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp
|
||||
the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp.
|
||||
"""
|
||||
fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1]
|
||||
filter_series = filter_series.astype("bool") # Make sure the filter_series is boolean
|
||||
@@ -144,17 +144,17 @@ class SeriesDFilter(BaseDFilter):
|
||||
return timestamp_series
|
||||
|
||||
def _toTimestamp(self, timestamp_series):
|
||||
"""Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE
|
||||
"""Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timestamp_series: pd.Series
|
||||
the series of bool value after being filtered
|
||||
the series of bool value after being filtered.
|
||||
|
||||
Returns
|
||||
----------
|
||||
list
|
||||
the list of tuple (timestamp, timestamp)
|
||||
the list of tuple (timestamp, timestamp).
|
||||
"""
|
||||
# sort the timestamp_series according to the timestamps
|
||||
timestamp_series.sort_index()
|
||||
@@ -194,18 +194,18 @@ class SeriesDFilter(BaseDFilter):
|
||||
Parameters
|
||||
----------
|
||||
instruments : dict
|
||||
the dict of instruments to be filtered
|
||||
the dict of instruments to be filtered.
|
||||
fstart : pd.Timestamp
|
||||
start time of filter
|
||||
start time of filter.
|
||||
fend : pd.Timestamp
|
||||
end time of filter
|
||||
end time of filter.
|
||||
|
||||
.. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time
|
||||
.. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time.
|
||||
|
||||
Returns
|
||||
----------
|
||||
pd.Dataframe
|
||||
a series of {pd.Timestamp => bool}
|
||||
a series of {pd.Timestamp => bool}.
|
||||
"""
|
||||
raise NotImplementedError("Subclass of SeriesDFilter must reimplement `getFilterSeries` method")
|
||||
|
||||
@@ -215,16 +215,16 @@ class SeriesDFilter(BaseDFilter):
|
||||
Parameters
|
||||
----------
|
||||
instruments: dict
|
||||
input instruments to be filtered
|
||||
input instruments to be filtered.
|
||||
start_time: str
|
||||
start of the time range
|
||||
start of the time range.
|
||||
end_time: str
|
||||
end of the time range
|
||||
end of the time range.
|
||||
|
||||
Returns
|
||||
----------
|
||||
dict
|
||||
filtered instruments, same structure as input instruments
|
||||
filtered instruments, same structure as input instruments.
|
||||
"""
|
||||
lbound, ubound = self._getTimeBound(instruments)
|
||||
start_time = pd.Timestamp(start_time or lbound)
|
||||
@@ -272,7 +272,7 @@ class NameDFilter(SeriesDFilter):
|
||||
params:
|
||||
------
|
||||
name_rule_re: str
|
||||
regular expression for the name rule
|
||||
regular expression for the name rule.
|
||||
"""
|
||||
super(NameDFilter, self).__init__(fstart_time, fend_time)
|
||||
self.name_rule_re = name_rule_re
|
||||
@@ -325,13 +325,13 @@ class ExpressionDFilter(SeriesDFilter):
|
||||
params:
|
||||
------
|
||||
fstart_time: str
|
||||
filter the feature starting from this time
|
||||
filter the feature starting from this time.
|
||||
fend_time: str
|
||||
filter the feature ending by this time
|
||||
filter the feature ending by this time.
|
||||
rule_expression: str
|
||||
an input expression for the rule
|
||||
an input expression for the rule.
|
||||
keep: bool
|
||||
whether to keep the instruments of which features don't exist in the filter time span
|
||||
whether to keep the instruments of which features don't exist in the filter time span.
|
||||
"""
|
||||
super(ExpressionDFilter, self).__init__(fstart_time, fend_time)
|
||||
self.rule_expression = rule_expression
|
||||
|
||||
@@ -18,9 +18,8 @@ try:
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi
|
||||
except ImportError as err:
|
||||
print(err)
|
||||
print("Do not import qlib package in the repository directory")
|
||||
sys.exit(-1)
|
||||
print("Do not import qlib package in the repository directory!")
|
||||
raise
|
||||
|
||||
__all__ = (
|
||||
"Ref",
|
||||
@@ -865,6 +864,8 @@ class Skew(Rolling):
|
||||
"""
|
||||
|
||||
def __init__(self, feature, N):
|
||||
if N != 0 and N < 3:
|
||||
raise ValueError("The rolling window size of Skewness operation should >= 3")
|
||||
super(Skew, self).__init__(feature, N, "skew")
|
||||
|
||||
|
||||
@@ -885,6 +886,8 @@ class Kurt(Rolling):
|
||||
"""
|
||||
|
||||
def __init__(self, feature, N):
|
||||
if N != 0 and N < 4:
|
||||
raise ValueError("The rolling window size of Kurtosis operation should >= 5")
|
||||
super(Kurt, self).__init__(feature, N, "kurt")
|
||||
|
||||
|
||||
@@ -1268,7 +1271,7 @@ class WMA(Rolling):
|
||||
|
||||
def weighted_mean(x):
|
||||
w = np.arange(len(x))
|
||||
w /= w.sum()
|
||||
w = w / w.sum()
|
||||
return np.nanmean(w * x)
|
||||
|
||||
if self.N == 0:
|
||||
|
||||
@@ -33,7 +33,7 @@ class Model(BaseModel):
|
||||
Parameters
|
||||
----------
|
||||
dataset : Dataset
|
||||
dataset will generate the processed data from model training
|
||||
dataset will generate the processed data from model training.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -44,7 +44,7 @@ class Model(BaseModel):
|
||||
Parameters
|
||||
----------
|
||||
dataset : Dataset
|
||||
dataset will generate the processed dataset from model training
|
||||
dataset will generate the processed dataset from model training.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -59,6 +59,6 @@ class ModelFT(Model):
|
||||
Parameters
|
||||
----------
|
||||
dataset : Dataset
|
||||
dataset will generate the processed dataset from model training
|
||||
dataset will generate the processed dataset from model training.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -23,9 +23,9 @@ class RiskModel(BaseModel):
|
||||
def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True):
|
||||
"""
|
||||
Args:
|
||||
nan_option (str): nan handling option (`ignore`/`mask`/`fill`)
|
||||
assume_centered (bool): whether the data is assumed to be centered
|
||||
scale_return (bool): whether scale returns as percentage
|
||||
nan_option (str): nan handling option (`ignore`/`mask`/`fill`).
|
||||
assume_centered (bool): whether the data is assumed to be centered.
|
||||
scale_return (bool): whether scale returns as percentage.
|
||||
"""
|
||||
# nan
|
||||
assert nan_option in [
|
||||
@@ -45,11 +45,11 @@ class RiskModel(BaseModel):
|
||||
Args:
|
||||
X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance,
|
||||
with variables as columns and observations as rows.
|
||||
return_corr (bool): whether return the correlation matrix
|
||||
is_price (bool): whether `X` contains price (if not assume stock returns)
|
||||
return_corr (bool): whether return the correlation matrix.
|
||||
is_price (bool): whether `X` contains price (if not assume stock returns).
|
||||
|
||||
Returns:
|
||||
pd.DataFrame or np.ndarray: estimated covariance (or correlation)
|
||||
pd.DataFrame or np.ndarray: estimated covariance (or correlation).
|
||||
"""
|
||||
# transform input into 2D array
|
||||
if not isinstance(X, (pd.Series, pd.DataFrame)):
|
||||
@@ -101,10 +101,10 @@ class RiskModel(BaseModel):
|
||||
By default, this method implements the empirical covariance estimation.
|
||||
|
||||
Args:
|
||||
X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows)
|
||||
X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).
|
||||
|
||||
Returns:
|
||||
np.ndarray: covariance matrix
|
||||
np.ndarray: covariance matrix.
|
||||
"""
|
||||
xTx = np.asarray(X.T.dot(X))
|
||||
N = len(X)
|
||||
@@ -117,7 +117,7 @@ class RiskModel(BaseModel):
|
||||
"""handle nan and centerize data
|
||||
|
||||
Note:
|
||||
if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`
|
||||
if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`.
|
||||
"""
|
||||
# handle nan
|
||||
if self.nan_option == self.FILL_NAN:
|
||||
@@ -139,15 +139,15 @@ class ShrinkCovEstimator(RiskModel):
|
||||
where `alpha` is the shrink parameter and `F` is the shrinking target.
|
||||
|
||||
The following shrinking parameters (`alpha`) are supported:
|
||||
- `lw` [1][2][3]: use Ledoit-Wolf shrinking parameter
|
||||
- `oas` [4]: use Oracle Approximating Shrinkage shrinking parameter
|
||||
- float: directly specify the shrink parameter, should be between [0, 1]
|
||||
- `lw` [1][2][3]: use Ledoit-Wolf shrinking parameter.
|
||||
- `oas` [4]: use Oracle Approximating Shrinkage shrinking parameter.
|
||||
- float: directly specify the shrink parameter, should be between [0, 1].
|
||||
|
||||
The following shrinking targets (`F`) are supported:
|
||||
- `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation
|
||||
- `const_corr` [2][6]: assume stocks have different variance but equal correlation
|
||||
- `single_factor` [3][7]: assume single factor model as the shrinking target
|
||||
- np.ndarray: provide the shrinking targets directly
|
||||
- `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation.
|
||||
- `const_corr` [2][6]: assume stocks have different variance but equal correlation.
|
||||
- `single_factor` [3][7]: assume single factor model as the shrinking target.
|
||||
- np.ndarray: provide the shrinking targets directly.
|
||||
|
||||
Note:
|
||||
- The optimal shrinking parameter depends on the selection of the shrinking target.
|
||||
@@ -402,13 +402,13 @@ class POETCovEstimator(RiskModel):
|
||||
def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs):
|
||||
"""
|
||||
Args:
|
||||
num_factors (int): number of factors (if set to zero, no factor model will be used)
|
||||
thresh (float): the positive constant for thresholding
|
||||
num_factors (int): number of factors (if set to zero, no factor model will be used).
|
||||
thresh (float): the positive constant for thresholding.
|
||||
thresh_method (str): thresholding method, which can be
|
||||
- 'soft': soft thresholding
|
||||
- 'hard': hard thresholding
|
||||
- 'scad': scad thresholding
|
||||
kwargs: see `RiskModel` for more information
|
||||
- 'soft': soft thresholding.
|
||||
- 'hard': hard thresholding.
|
||||
- 'scad': scad thresholding.
|
||||
kwargs: see `RiskModel` for more information.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
40
qlib/model/trainer.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
|
||||
|
||||
def task_train(config: dict, experiment_name):
|
||||
"""
|
||||
task based training
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : dict
|
||||
A dict describing the training process
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(config.get("task")["model"])
|
||||
dataset = init_instance_by_config(config.get("task")["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name=experiment_name):
|
||||
# train model
|
||||
R.log_params(**flatten_dict(config.get("task")))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
for record in config.get("task")["record"]:
|
||||
if record["class"] == SignalRecord.__name__:
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||