diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index eb2e7d09f..033d31536 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -43,16 +43,17 @@ jobs:
- name: Lint with Black
run: |
cd ..
- python -m black qlib -l 120
+ python -m black qlib -l 120 --check --diff
- name: Unit tests with Pytest
run: |
cd tests
pytest . --durations=0
- - name: Test data downloads and examples
+ - name: Test data downloads
run: |
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
- # cd examples
- # estimator -c estimator/estimator_config.yaml
- # jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html
\ No newline at end of file
+
+ - name: Test workflow by config
+ run: |
+ qrun examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
diff --git a/README.md b/README.md
index 5b385e563..535f77376 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
-
+
@@ -28,6 +28,8 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
- [Quant Model Zoo](#quant-model-zoo)
+ - [Run a single model](#run-a-single-model)
+ - [Run multiple models](#run-multiple-models)
- [Quant Dataset Zoo](#quant-dataset-zoo)
- [More About Qlib](#more-about-qlib)
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
@@ -39,19 +41,17 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
# Framework of Qlib
-

+
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
-| Name | Description |
-| ------ | ----- |
-| `Data layer` | `DataServer` focuses on providing high-performance infrastructure for users to manage and retrieve raw data. `DataEnhancement` will preprocess the data and provide the best dataset to be fed into the models. |
-| `Interday Model` | `Interday model` focuses on producing prediction scores (aka. _alpha_). Models are trained by `Model Creator` and managed by `Model Manager`. Users could choose one or multiple models for prediction. Multiple models could be combined with `Ensemble` module. |
-| `Interday Strategy` | `Portfolio Generator` will take prediction scores as input and output the orders based on the current position to achieve the target portfolio. |
-| `Intraday Trading` | `Order Executor` is responsible for executing orders output by `Interday Strategy` and returning the executed results. |
-| `Analysis` | Users could get a detailed analysis report of forecasting signals and portfolios in this part. |
+| Name | Description |
+| ------ | ----- |
+| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
+| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
+| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
* The modules with hand-drawn style are under development and will be released in the future.
* The modules with dashed borders are highly user-customizable and extendible.
@@ -128,50 +128,53 @@ Users could create the same dataset with it.
-->
## Auto Quant Research Workflow
-Qlib provides a tool named `Estimator` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
+Qlib provides a tool named `qrun` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). You can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
-1. Quant Research Workflow: Run `Estimator` with [estimator_config.yaml](examples/estimator/estimator_config.yaml) as following. (*Please note that this may **not work** under MacOS with Python 3.8 due to the incompatibility of the `sacred` package we use with Python 3.8. We will fix this bug in the future.*)
+1. Quant Research Workflow: Run `qrun` with lightgbm workflow config ([workflow_config_lightgbm.yaml](examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml)) as following.
```bash
cd examples # Avoid running program under the directory contains `qlib`
- estimator -c estimator/estimator_config.yaml
+ qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
```
- The result of `Estimator` is as follows, please refer to please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
+ The result of `qrun` is as follows, please refer to please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
```bash
- risk
- excess_return_without_cost mean 0.000675
- std 0.005456
- annualized_return 0.170077
- information_ratio 1.963824
- max_drawdown -0.063646
- excess_return_with_cost mean 0.000479
- std 0.005453
- annualized_return 0.120776
- information_ratio 1.395116
- max_drawdown -0.071216
+ 'The following are analysis results of the excess return without cost.'
+ risk
+ mean 0.000708
+ std 0.005626
+ annualized_return 0.178316
+ information_ratio 1.996555
+ max_drawdown -0.081806
+ 'The following are analysis results of the excess return with cost.'
+ risk
+ mean 0.000512
+ std 0.005626
+ annualized_return 0.128982
+ information_ratio 1.444287
+ max_drawdown -0.091078
```
- Here are detailed documents for [Estimator](https://qlib.readthedocs.io/en/latest/component/estimator.html).
+ Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html).
-2. Graphical Reports Analysis: Run `examples/estimator/analyze_from_estimator.ipynb` with `jupyter notebook` to get graphical reports
+2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
- Forecasting signal (model prediction) analysis
- Cumulative Return of groups
- 
+ 
- Return distribution
- 
+ 
- Information Coefficient (IC)
- 
- 
- 
+ 
+ 
+ 
- Auto Correlation of forecasting signal (model prediction)
- 
+ 
- Portfolio analysis
- Backtest return
- 
+ 
## Building Customized Quant Research Workflow by Code
-The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/train_backtest_analyze.ipynb) is a demo for customized Quant research workflow by code
+The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
-# Quant Model Zoo
+# [Quant Model Zoo](examples/benchmarks)
Here is a list of models built on `Qlib`.
-- [GBDT based on lightgbm](qlib/contrib/model/gbdt.py)
+- [GBDT based on LightGBM](qlib/contrib/model/gbdt.py)
+- [GBDT based on Catboost](qlib/contrib/model/catboost_model.py)
+- [GBDT based on XGBoost](qlib/contrib/model/xgboost.py)
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
+- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
+- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
+- [ALSTM based on pytorcn](qlib/contrib/model/pytorch_alstm.py)
+- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
+- [SFM based on pytorch](qlib/contrib/model/pytorch_sfm.py)
+
Your PR of new Quant models is highly welcomed.
+## Run a single model
+All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
+
+`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
+- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
+- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
+
+- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
+
+## Run multiple models
+`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only supprots *Linux* now. Other OS will be supported in the future.)
+
+The script will create a unique virtual environment for each model, and delete the environments after training. Thus, only experiment results such as `IC` and `backtest` results will be generated and stored. (**Note**: the script will erase your previous experiment records created by running itself.)
+
+Here is an example of running all the models for 10 iterations:
+```python
+python run_all_model.py 10
+```
+
+It also provides the API to run specific models at once. For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
+
+
# Quant Dataset Zoo
Dataset plays a very important role in Quant. Here is a list of the datasets built on `Qlib`.
-- [Alpha360](./qlib/contrib/estimator/handler.py)
-- [Alpha158](./qlib/contrib/estimator/handler.py)
+
+| Dataset | US Market | China Market |
+| -- | -- | -- |
+| [Alpha360](./qlib/contrib/data/handler.py) | √ | √ |
+| [Alpha158](./qlib/contrib/data/handler.py) | √ | √ |
[Here](https://qlib.readthedocs.io/en/latest/advanced/alpha.html) is a tutorial to build dataset with `Qlib`.
Your PR to build new Quant dataset is highly welcomed.
diff --git a/docs/FAQ/FAQ.rst b/docs/FAQ/FAQ.rst
index 8ea5e9f82..ba6f77b47 100644
--- a/docs/FAQ/FAQ.rst
+++ b/docs/FAQ/FAQ.rst
@@ -62,6 +62,7 @@ It sees the key of the redis lock has existed in your redis db now. You can use
> select 1
> flushdb
+If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If so, try using ``flushall`` to clear all the keys.
.. note::
diff --git a/docs/_static/img/analysis/analysis_model_IC.png b/docs/_static/img/analysis/analysis_model_IC.png
index 0064fb890..26b4b4bfa 100644
Binary files a/docs/_static/img/analysis/analysis_model_IC.png and b/docs/_static/img/analysis/analysis_model_IC.png differ
diff --git a/docs/_static/img/analysis/analysis_model_NDQ.png b/docs/_static/img/analysis/analysis_model_NDQ.png
index c1824368b..5197c4b03 100644
Binary files a/docs/_static/img/analysis/analysis_model_NDQ.png and b/docs/_static/img/analysis/analysis_model_NDQ.png differ
diff --git a/docs/_static/img/analysis/analysis_model_auto_correlation.png b/docs/_static/img/analysis/analysis_model_auto_correlation.png
index 3f213a79b..ab9e30165 100644
Binary files a/docs/_static/img/analysis/analysis_model_auto_correlation.png and b/docs/_static/img/analysis/analysis_model_auto_correlation.png differ
diff --git a/docs/_static/img/analysis/analysis_model_cumulative_return.png b/docs/_static/img/analysis/analysis_model_cumulative_return.png
index bcccf138a..c305a42b4 100644
Binary files a/docs/_static/img/analysis/analysis_model_cumulative_return.png and b/docs/_static/img/analysis/analysis_model_cumulative_return.png differ
diff --git a/docs/_static/img/analysis/analysis_model_long_short.png b/docs/_static/img/analysis/analysis_model_long_short.png
index 2fcb08c4e..5efed2d6c 100644
Binary files a/docs/_static/img/analysis/analysis_model_long_short.png and b/docs/_static/img/analysis/analysis_model_long_short.png differ
diff --git a/docs/_static/img/analysis/analysis_model_monthly_IC.png b/docs/_static/img/analysis/analysis_model_monthly_IC.png
index 0056c6c9c..8443f3860 100644
Binary files a/docs/_static/img/analysis/analysis_model_monthly_IC.png and b/docs/_static/img/analysis/analysis_model_monthly_IC.png differ
diff --git a/docs/_static/img/analysis/report.png b/docs/_static/img/analysis/report.png
index dfd227f5a..2901da603 100644
Binary files a/docs/_static/img/analysis/report.png and b/docs/_static/img/analysis/report.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_annualized_return.png b/docs/_static/img/analysis/risk_analysis_annualized_return.png
index 1979ca19b..18e7a90aa 100644
Binary files a/docs/_static/img/analysis/risk_analysis_annualized_return.png and b/docs/_static/img/analysis/risk_analysis_annualized_return.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_bar.png b/docs/_static/img/analysis/risk_analysis_bar.png
index 1cce1f340..c90650a6d 100644
Binary files a/docs/_static/img/analysis/risk_analysis_bar.png and b/docs/_static/img/analysis/risk_analysis_bar.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_information_ratio.png b/docs/_static/img/analysis/risk_analysis_information_ratio.png
index edc64b17d..7028eaf02 100644
Binary files a/docs/_static/img/analysis/risk_analysis_information_ratio.png and b/docs/_static/img/analysis/risk_analysis_information_ratio.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_max_drawdown.png b/docs/_static/img/analysis/risk_analysis_max_drawdown.png
index a68810222..b7f1ae130 100644
Binary files a/docs/_static/img/analysis/risk_analysis_max_drawdown.png and b/docs/_static/img/analysis/risk_analysis_max_drawdown.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_std.png b/docs/_static/img/analysis/risk_analysis_std.png
index 73d782e20..6f38def26 100644
Binary files a/docs/_static/img/analysis/risk_analysis_std.png and b/docs/_static/img/analysis/risk_analysis_std.png differ
diff --git a/docs/_static/img/analysis/score_ic.png b/docs/_static/img/analysis/score_ic.png
index 6e1d37d2a..a5739a9ba 100644
Binary files a/docs/_static/img/analysis/score_ic.png and b/docs/_static/img/analysis/score_ic.png differ
diff --git a/docs/_static/img/framework.png b/docs/_static/img/framework.png
index 673f10e03..d8242f7c1 100644
Binary files a/docs/_static/img/framework.png and b/docs/_static/img/framework.png differ
diff --git a/docs/advanced/alpha.rst b/docs/advanced/alpha.rst
index be30ea8a7..e6146dd0c 100644
--- a/docs/advanced/alpha.rst
+++ b/docs/advanced/alpha.rst
@@ -1,4 +1,5 @@
.. _alpha:
+
===========================
Building Formulaic Alphas
===========================
@@ -49,7 +50,7 @@ Users can use ``Data Handler`` to build formulaic alphas `MACD` in qlib:
.. code-block:: python
- >> from qlib.contrib.estimator.handler import QLibDataHandler
+ >> from qlib.data.dataset.handler import QLibDataHandler
>> MACD_EXP = '(EMA($close, 12) - EMA($close, 26))/$close - EMA((EMA($close, 12) - EMA($close, 26))/$close, 9)/$close'
>> fields = [MACD_EXP] # MACD
>> names = ['MACD']
diff --git a/docs/advanced/server.rst b/docs/advanced/server.rst
index 230c4f04b..a8a764b91 100644
--- a/docs/advanced/server.rst
+++ b/docs/advanced/server.rst
@@ -1,4 +1,5 @@
.. _server:
+
=================================
``Online`` & ``Offline`` mode
=================================
diff --git a/docs/component/backtest.rst b/docs/component/backtest.rst
index fd4ac19fa..d36dba316 100644
--- a/docs/component/backtest.rst
+++ b/docs/component/backtest.rst
@@ -1,4 +1,5 @@
.. _backtest:
+
============================================
Intraday Trading: Model&Strategy Testing
============================================
diff --git a/docs/component/data.rst b/docs/component/data.rst
index 195ed482d..55d6c7207 100644
--- a/docs/component/data.rst
+++ b/docs/component/data.rst
@@ -1,6 +1,7 @@
.. _data:
+
================================
-Data Layer: Data Framework&Usage
+Data Layer: Data Framework & Usage
================================
Introduction
@@ -14,7 +15,9 @@ The introduction of ``Data Layer`` includes the following parts.
- Data Preparation
- Data API
+- Data Loader
- Data Handler
+- Dataset
- Cache
- Data and Cache File Structure
@@ -30,13 +33,19 @@ Such data will be stored with filename suffix `.bin` (We'll call them `.bin` fil
Qlib Format Dataset
--------------------
-``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the dataset as follows.
+``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
-After running the above command, users can find china-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory.
+In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
+
+.. code-block:: bash
+
+ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
+
+After running the above command, users can find china-stock and us-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
@@ -48,12 +57,45 @@ Converting CSV Format into Qlib Format
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert data in CSV format into `.bin` files (Qlib format).
-Users can download the china-stock data in CSV format as follows for reference to the CSV format.
+Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
.. code-block:: bash
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
+Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
+
+- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
+
+ - Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
+
+ - CSV file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
+
+ .. code-block:: bash
+
+ python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
+
+ where the data are in the following format:
+
+ .. code-block::
+
+ symbol,close
+ SH600000,120
+
+- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
+
+ .. code-block:: bash
+
+ python scripts/dump_bin.py dump_all ... --date_field_name date
+
+ where the data are in the following format:
+
+ .. code-block::
+
+ symbol,date,close,open,volume
+ SH600000,2020-11-01,120,121,12300000
+ SH600000,2020-11-02,123,120,12300000
+
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
@@ -61,6 +103,12 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
+For other supported parameters when dumping the data into `.bin` file, users can refer to the information by running the following commands:
+
+.. code-block:: bash
+
+ python dump_bin.py dump_all --help
+
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
.. note::
@@ -96,9 +144,8 @@ China-Stock Mode & US-Stock Mode
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)
-- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` does not provide a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
- - Prepare data in CSV format
- - Convert data from CSV format to Qlib format, please refer to section `Converting CSV Format into Qlib Format <#converting-csv-format-into-qlib-format>`_.
+- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
+ - Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in US-stock mode
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
@@ -140,72 +187,97 @@ Filter
Expression dynamic instrument filter. Filter the instruments based on a certain expression. An expression rule indicating a certain feature field is required.
- `basic features filter`: rule_expression = '$close/$open>5'
- - `cross-sectional features filter` : rule_expression = '$rank($close)<10'
+ - `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
-
Reference
-------------
To know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_.
+
+Data Loader
+=================
+
+``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It will be loaded and used in the ``Data Handler`` module.
+
+QlibDataLoader
+---------------
+
+The ``QlibDataLoader`` class in ``Qlib`` is such an interface that allows users to load raw data from the data source.
+
+Interface
+------------
+
+Here are some interfaces of the ``QlibDataLoader`` class:
+
+.. autoclass:: qlib.data.dataset.loader.QlibDataLoader
+ :members: load, load_group_df
+
+API
+-----------
+
+To know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_.
+
+
Data Handler
=================
-Users can use ``Data Handler`` in an automatic workflow by ``Estimator``, refer to `Estimator: Workflow Management `_ for more details.
+The ``Data Handler`` module in ``Qlib`` is designed to handler those common data processing methods which will be used by most of the models.
-Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.contrib.estimator.handler.BaseDataHandler``, which provides some interfaces as follows.
+Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management `_ for more details.
-Base Class & Interface
+DataHandlerLP
+--------------
+
+In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
+
+In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some leanable ``Processors`` which can learn the parameters of data processing. When new data comes in, these `trained` ``Processors`` can then infer on the new data and thus processing real-time data in an efficient way. More information about ``Processors`` will be listed in the next subsection.
+
+
+Interface
----------------------
-Qlib provides a base class `qlib.contrib.estimator.BaseDataHandler <../reference/api.html#qlib.contrib.estimator.handler.BaseDataHandler>`_, which provides the following interfaces:
+Here are some important interfaces that ``DataHandlerLP`` provides:
-- `setup_feature`
- Implement the interface to load the data features.
+.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
+ :members: __init__, fetch, get_cols
-- `setup_label`
- Implement the interface to load the data labels and calculate the users' labels.
+If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
-- `setup_processed_data`
- Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on.
-
-Qlib also provides two functions to help users init the data handler, users can override them for users' needs.
-
-- `_init_kwargs`
- Users can init the kwargs of the data handler in this function, some kwargs may be used when init the raw df.
- Kwargs are the other attributes in data.args, like dropna_label, dropna_feature
-
-- `_init_raw_df`
- Users can init the raw df, feature names, and label names of data handler in this function.
- If the index of feature df and label df are not the same, users need to override this method to merge them (e.g. inner, left, right merge).
-
-If users want to load features and labels by config, users can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
-Usage
---------------
+Processor
+----------
-``Data Handler`` can be used as a single module, which provides the following mehtods:
+The ``Processor`` module in ``Qlib`` is designed to be learnable and it is responsible for handling data processing such as `normalization` and `drop none/nan features/labels`.
-- `get_split_data`
- - According to the start and end dates, return features and labels of the pandas DataFrame type used for the 'Model'
-
-- `get_rolling_data`
- - According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling.
+``Qlib`` provides the following ``Processors``:
+- ``DropnaProcessor``: `processor` that drops N/A features.
+- ``DropnaLabel``: `processor` that drops N/A labels.
+- ``TanhProcess``: `processor` that uses `tanh` to process noise data.
+- ``ProcessInf``: `processor` that handles infinity values, it will be replaces by the mean of the column.
+- ``Fillna``: `processor` that handles N/A values, which will fill the N/A value by 0 or other given number.
+- ``MinMaxNorm``: `processor` that applies min-max normalization.
+- ``ZscoreNorm``: `processor` that applies z-score normalization.
+- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
+- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
+- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
+Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link `_).
+To know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_.
Example
--------------
-``Data Handler`` can be run with ``estimator`` by modifying the configuration file, and can also be used as a single module.
+``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.
-Know more about how to run ``Data Handler`` with ``Estimator``, please refer to `Estimator: Workflow Management `_
+Know more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management `_
Qlib provides implemented data handler `Alpha158`. The following example shows how to run `Alpha158` as a single module.
@@ -214,44 +286,54 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
.. code-block:: Python
- from qlib.contrib.estimator.handler import Alpha158
- from qlib.contrib.model.gbdt import LGBModel
+ import qlib
+ from qlib.contrib.data.handler import Alpha158
- DATA_HANDLER_CONFIG = {
- "dropna_label": True,
- "start_date": "2007-01-01",
- "end_date": "2020-08-01",
- "market": "csi300",
+ data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": "csi300",
}
- TRAINER_CONFIG = {
- "train_start_date": "2007-01-01",
- "train_end_date": "2014-12-31",
- "validate_start_date": "2015-01-01",
- "validate_end_date": "2016-12-31",
- "test_start_date": "2017-01-01",
- "test_end_date": "2020-08-01",
- }
+ if __name__ == "__main__":
+ qlib.init()
+ h = Alpha158(**data_handler_config)
- exampleDataHandler = Alpha158(**DATA_HANDLER_CONFIG)
+ # get all the columns of the data
+ print(h.get_cols())
- # example of 'get_split_data'
- x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG)
+ # fetch all the labels
+ print(h.fetch(col_set="label"))
- # example of 'get_rolling_data'
-
- for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG):
- print(x_train, y_train, x_validate, y_validate, x_test, y_test)
-
-
-.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the `fit`, `predic``, and `score` methods of the ``Interday Model`` , please refer to `Model `_.
-
-Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
+ # fetch all the features
+ print(h.fetch(col_set="feature"))
API
---------
-To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.contrib.estimator.handler>`_.
+To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_.
+
+
+Dataset
+=================
+
+The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
+
+The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the rights to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``MLP`` will break down on such data.
+
+The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most important interface of the class:
+
+.. autoclass:: qlib.data.dataset.__init__.DatasetH
+ :members:
+
+API
+---------
+
+To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
+
+
Cache
==========
diff --git a/docs/component/estimator.rst b/docs/component/estimator.rst
deleted file mode 100644
index 917d73c13..000000000
--- a/docs/component/estimator.rst
+++ /dev/null
@@ -1,706 +0,0 @@
-.. _estimator:
-=================================
-Estimator: Workflow Management
-=================================
-.. currentmodule:: qlib
-
-Introduction
-===================
-
-The components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users could build their own Quant research workflow with these components like `Example `_
-
-
-Besides, ``Qlib`` provides more user-friendly interfaces named ``Estimator`` to automatically run the whole workflow defined by configuration. A concrete execution of the whole workflow is called an `experiment`.
-With ``Estimator``, user can easily run an `experiment`, which includes the following steps:
-
-- Data
- - Loading
- - Processing
- - Slicing
-- Model
- - Training and inference(static or rolling)
- - Saving & loading
-- Evaluation(Back-testing)
-
-For each `experiment`, ``Qlib`` will capture the model training details, performance evaluation results and basic information (e.g. names, ids). The captured data will be stored in backend-storage (disk or database).
-
-Complete Example
-===================
-
-Before getting into details, here is a complete example of ``Estimator``, which defines the workflow in typical Quant research.
-Below is a typical config file of ``Estimator``.
-
-.. code-block:: YAML
-
- experiment:
- name: estimator_example
- observer_type: file_storage
- mode: train
- model:
- class: LGBModel
- module_path: qlib.contrib.model.gbdt
- args:
- loss: mse
- colsample_bytree: 0.8879
- learning_rate: 0.0421
- subsample: 0.8789
- lambda_l1: 205.6999
- lambda_l2: 580.9768
- max_depth: 8
- num_leaves: 210
- num_threads: 20
- data:
- class: Alpha158
- args:
- dropna_label: True
- filter:
- market: csi500
- trainer:
- class: StaticTrainer
- args:
- rolling_period: 360
- train_start_date: 2007-01-01
- train_end_date: 2014-12-31
- validate_start_date: 2015-01-01
- validate_end_date: 2016-12-31
- test_start_date: 2017-01-01
- test_end_date: 2020-08-01
- strategy:
- class: TopkDropoutStrategy
- args:
- topk: 50
- n_drop: 5
- backtest:
- normal_backtest_args:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: SH000905
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
- qlib_data:
- # when testing, please modify the following parameters according to the specific environment
- provider_uri: "~/.qlib/qlib_data/cn_data"
- region: "cn"
-
-After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
-
-.. code-block:: bash
-
- estimator -c configuration.yaml
-
-.. note:: `estimator` will be placed in your $PATH directory when installing ``Qlib``.
-
-
-
-Configuration File
-===================
-
-Let's get into details of ``Estimator`` in this section.
-
-Before using ``estimator``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
-
-Experiment Section
---------------------
-
-At first, the configuration file needs to contain a section named `experiment` about the basic information. This section describes how `estimator` tracks and persists current `experiment`. ``Qlib`` used `sacred`, a lightweight open-source tool, to configure, organize, generate logs, and manage experiment results. Partial behaviors of `sacred` will base on the `experiment` section.
-
-Following files will be saved by `sacred` after `estimator` finish an `experiment`:
-
-- `model.bin`, model binary file
-- `pred.pkl`, model prediction result file
-- `analysis.pkl`, backtest performance analysis file
-- `positions.pkl`, backtest position records file
-- `run`, the experiment information object, usually contains some meta information such as the experiment name, experiment date, etc.
-
-Here is the typical configuration of `experiment section`
-
-.. code-block:: YAML
-
- experiment:
- name: test_experiment
- observer_type: mongo
- mongo_url: mongodb://MONGO_URL
- db_name: public
- finetune: false
- exp_info_path: /home/test_user/exp_info.json
- mode: test
- loader:
- id: 677
-
-
-The meaning of each field is as follows:
-
-- `name`
- The experiment name, str type, `sacred _` will use this experiment name as an identifier for some important internal processes. Users can find this field in `run` object of `sacred`. The default value is `test_experiment`.
-
-- `observer_type`
- Observer type, str type, there are two choices which include `file_storage` and `mongo` respectively. If `file_storage` is selected, all the above-mentioned managed contents will be stored in the `dir` directory, separated by the number of times of experiments as a subfolder. If it is `mongo`, the content will be stored in the database. The default is `file_storage`.
-
- - For `file_storage` observer.
- - `dir`
- Directory URL, str type, directory for `file_storage` observer type, files captured and managed by sacred with `file_storage` observer will be saved to this directory, which is the same directory as `config.json` by default.
-
- - For `mongo` observer.
- - `mongo_url`
- Database URL, str type, required if the observer type is `mongo`.
-
- - `db_name`
- Database name, str type, required if the observer type is `mongo`.
-
-- `finetune`
- ``Estimator``'s behaviors to train models will base on this flag.
- If you just want to train models from scratch each time instead of based on existing models, please leave `finetune=false`. Otherwise please read the
- details below.
-
- The following table is the processing logic for different situations.
-
- ========== =========================================== ==================================== =========================================== ==========================================
- . Static Rolling
- . finetune:true finetune:false finetune:true finetune:false
- ========== =========================================== ==================================== =========================================== ==========================================
- Train - Need to provide model (Static or Rolling) - No need to provide model - Need to provide model (Static or Rolling) - Need to provide model (Static or Rolling)
- - The args in model section will be - The args in model section will be - The args in model section will be - The args in model section will be
- used for finetuning used for training used for finetuning used for finetuning
- - Update based on the provided model - Train model from scratch - Update based on the provided model - Based on the provided model update
- and parameters and parameters - Train model from scratch
- - **Each rolling time slice is based on** - **Train each rolling time slice**
- **a model updated from the previous** **separately**
- **time**
- Test - Model must exist, otherwise an exception will be raised.
- - For `StaticTrainer`, users need to train a model and record 'exp_info' for 'Test'.
- - For `RollingTrainer`, users need to train a set of models until the latest time, and record 'exp_info' for 'Test'.
- ========== =============================================================================================================================================================================
-
- .. note::
-
- 1. finetune parameters: share model.args parameters.
-
- 2. provide model: from `loader.model_index`, load the index of the model(starting from 0).
-
- 3. If `loader.model_index` is None:
- - In 'Static Finetune=True', if provide 'Rolling', use the last model to update.
-
- - For `RollingTrainer` with Finetune=True.
-
- - If `StaticTrainer` is used in loader, the model will be used for initialization for finetuning.
-
- - If `RollingTrainer` is used in loader, the existing models will be used without any modification and the new models will be initialized with the model in the last period and finetune one by one.
-
-
-- `exp_info_path`
- save path of experiment info, str type, save the experiment info and model `prediction score` after the experiment is finished. Optional parameter, the default value is `/ex_name/exp_info.json`.
-
-- `mode`
- `train` or `test`, str type.
- - `test mode` is designed for inference. Under `test mode`, it will load the model according to the parameters of `loader` and skip model training.
- - `train model` is the default value. It will train new models by default and
- Please note that when it fails to load model, it will fall back to `fit` model.
-
- .. note::
-
- if users choose ` test mode`, they need to make sure:
- - The loader of `test_start_date` must be less than or equal to the current `test_start_date`.
- - If other parameters of the `loader` model args are different, a warning will appear.
-
-
-- `loader`
- If you just want to train models from scratch each time instead of based on existing models, please ignore `loader` section. Otherwise please read the
- details below.
-
- The `loader` section only works when the `mode` is `test` or `finetune` is `true`.
-
- - `model_index`
- Model index, int type. The index of the loaded model in loader_models (starting at 0) for the first `finetune`. The default value is None.
-
- - `exp_info_path`
- Loader model experiment info path, str type. If the field exists, the following parameters will be parsed from `exp_info_path`, and the following parameters will not work. One of this field and `id` must exist at least .
-
- - `id`
- The experiment id of the model that needs to be loaded, int type. If the `mode` is `test`, this value is required. This field and `exp_info_path` must exist one.
-
- - `name`
- The experiment name of the model that needs to be loaded, str type. The default value is the current experiment `name`.
-
- - `observer_type`
- The experiment observer type of the model that needs to be loaded, str type. The default value is the current experiment `observer_type`.
-
- .. note:: The observer type is a concept of the `sacred` module, which determines how files, standard input, and output which are managed by sacred are stored.
-
-
- - `file_storage`
- If `observer_type` is `file_storage`, the config may be as follows.
-
- .. code-block:: YAML
-
- experiment:
- name: test_experiment
- dir: # default is dir of `config.yml`
- observer_type: file_storage
- - `mongo`
- If `observer_type` is `mongo`, the config may be as follows.
-
- .. code-block:: YAML
-
- experiment:
- name: test_experiment
- observer_type: mongo
- mongo_url: mongodb://MONGO_URL
- db_name: public
-
- Users need to indicate `mongo_url` and `db_name` for a mongo observer.
-
- .. note::
-
- If users choose the mongo observer, they need to make sure:
- - Have an environment with the mongodb installed and a mongo database dedicated to storing the results of the experiments.
- - The python environment (the version of python and package) to run the experiments and the one to fetch the results are consistent.
-
-Model Section
------------------
-
-Users can use a specified model by configuration with hyper-parameters.
-
-Custom Models
-~~~~~~~~~~~~~~~~~
-
-Qlib supports custom models, but it must be a subclass of the `qlib.contrib.model.Model`, the config for a custom model may be as following.
-
-.. code-block:: YAML
-
- model:
- class: SomeModel
- module_path: /tmp/my_experment/custom_model.py
- args:
- loss: binary
-
-
-The class `SomeModel` should be in the module `custom_model`, and ``Qlib`` could parse the `module_path` to load the class.
-
-To know more about ``Interday Model``, please refer to `Interday Model: Training & Prediction `_.
-
-Data Section
------------------
-
-``Data Handler`` can be used to load raw data, prepare features and label columns, preprocess data (standardization, remove NaN, etc.), split training, validation, and test sets. It is a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`.
-
-Users can use the specified data handler by config as follows.
-
-.. code-block:: YAML
-
- data:
- class: Alpha158
- args:
- start_date: 2005-01-01
- end_date: 2018-04-30
- dropna_label: True
- filter:
- market: csi500
- filter_pipeline:
- -
- class: NameDFilter
- module_path: qlib.filter
- args:
- name_rule_re: S(?!Z3)
- fstart_time: 2018-01-01
- fend_time: 2018-12-11
- -
- class: ExpressionDFilter
- module_path: qlib.filter
- args:
- rule_expression: $open/$factor<=45
- fstart_time: 2018-01-01
- fend_time: 2018-12-11
-
-- `class`
- Data handler class, str type, which should be a subclass of `qlib.contrib.estimator.handler.BaseDataHandler`, and implements 5 important interfaces for loading features, loading raw data, preprocessing raw data, slicing train, validation, and test data. The default value is `ALPHA360`. If users want to write a data handler to retrieve the data in ``Qlib``, `QlibDataHandler` is suggested.
-
-- `module_path`
- The module path, str type, absolute url is also supported, indicates the path of the `class` implementation of the data processor class. The default value is `qlib.contrib.estimator.handler`.
-
-- `args`
- Parameters used for ``Data Handler`` initialization.
-
- - `train_start_date`
- Training start time, str type, the default value is `2005-01-01`.
-
- - `start_date`
- Data start date, str type.
-
- - `end_date`
- Data end date, str type. the data from start_date to end_date decides which part of data will be loaded in `datahandler`, users can only use these data in the following parts.
-
- - `dropna_feature` (Optional in args)
- Drop Nan feature, bool type, the default value is False.
-
- - `dropna_label` (Optional in args)
- Drop Nan label, bool type, the default value is True. Some multi-label tasks will use this.
-
- - `normalize_method` (Optional in args)
- Normalize data by a given method. str type. ``Qlib`` gives two normalizing methods, `MinMax` and `Std`.
- If users want to build their own method, please override `_process_normalize_feature`.
-
-- `filter`
- Dynamically filtering the stocks based on the filter pipeline.
-
- - `market`
- index name, str type, the default value is `csi500`.
-
- - `filter_pipeline`
- Filter rule list, list type, the default value is []. Can be customized according to users' needs.
-
- - `class`
- Filter class name, str type.
-
- - `module_path`
- The module path, str type.
-
- - `args`
- The filter class parameters, these parameters are set according to the `class`, and all the parameters as kwargs to `class`.
-
-Custom Data Handler
-~~~~~~~~~~~~~~~~~~~~~~
-
-Qlib support custom data handler, but it must be a subclass of the ``qlib.contrib.estimator.handler.BaseDataHandler``, the config for custom data handler may be as follows.
-
-.. code-block:: YAML
-
- data:
- class: SomeDataHandler
- module_path: /tmp/my_experment/custom_data_handler.py
- args:
- start_date: 2005-01-01
- end_date: 2018-04-30
-
-The class `SomeDataHandler` should be in the module `custom_data_handler`, and ``Qlib`` could parse the `module_path` to load the class.
-
-If users want to load features and labels by config, they can inherit ``qlib.contrib.estimator.handler.ConfigDataHandler``, ``Qlib`` also has provided some preprocess methods in this subclass.
-If users want to use qlib data, `QLibDataHandler` is recommended, from which users can inherit the custom class. `QLibDataHandler` is also a subclass of `ConfigDataHandler`.
-
-To know more about ``Data Handler``, please refer to `Data Framework&Usage `_.
-
-Trainer Section
------------------
-
-Users can specify the trainer ``Trainer`` by the config file, which is a subclass of ``qlib.contrib.estimator.trainer.BaseTrainer`` and implement three important interfaces for training the model, restoring the model, and getting model predictions as follows.
-
-- `train`
- Implement this interface to train the model.
-
-- `load`
- Implement this interface to recover the model from disk.
-
-- `get_pred`
- Implement this interface to get model prediction results.
-
-Qlib have provided two implemented trainer,
-
-- `StaticTrainer`
- The static trainer will be trained using the training, validation, and test data of the data processor static slicing.
-
-- `RollingTrainer`
- The rolling trainer will use the rolling iterator of the data processor to split data for rolling training.
-
-
-Users can specify `trainer` with the configuration file:
-
-.. code-block:: YAML
-
- trainer:
- class: StaticTrainer # or RollingTrainer
- args:
- rolling_period: 360
- train_start_date: 2005-01-01
- train_end_date: 2014-12-31
- validate_start_date: 2015-01-01
- validate_end_date: 2016-06-30
- test_start_date: 2016-07-01
- test_end_date: 2017-07-31
-
-- `class`
- Trainer class, which should be a subclass of `qlib.contrib.estimator.trainer.BaseTrainer`, and needs to implement three important interfaces, the default value is `StaticTrainer`.
-
-- `module_path`
- The module path, str type, absolute url is also supported, indicates the path of the trainer class implementation.
-
-- `args`
- Parameters used for ``Trainer`` initialization.
-
- - `rolling_period`
- The rolling period, integer type, indicates how many time steps need rolling when rolling the data. The default value is `60`. Only used in `RollingTrainer`.
-
- - `train_start_date`
- Training start time, str type.
-
- - `train_end_date`
- Training end time, str type.
-
- - `validate_start_date`
- Validation start time, str type.
-
- - `validate_end_date`
- Validation end time, str type.
-
- - `test_start_date`
- Test start time, str type.
-
- - `test_end_date`
- Test end time, str type. If `test_end_date` is `-1` or greater than the last date of the data, the last date of the data will be used as `test_end_date`.
-
-Custom Trainer
-~~~~~~~~~~~~~~~~~~
-
-Qlib supports custom trainer, but it must be a subclass of the `qlib.contrib.estimator.trainer.BaseTrainer`, the config for a custom trainer may be as following:
-
-.. code-block:: YAML
-
- trainer:
- class: SomeTrainer
- module_path: /tmp/my_experment/custom_trainer.py
- args:
- train_start_date: 2005-01-01
- train_end_date: 2014-12-31
- validate_start_date: 2015-01-01
- validate_end_date: 2016-06-30
- test_start_date: 2016-07-01
- test_end_date: 2017-07-31
-
-
-The class `SomeTrainer` should be in the module `custom_trainer`, and ``Qlib`` could parse the `module_path` to load the class.
-
-Strategy Section
------------------
-
-Users can specify strategy through a config file, for example:
-
-.. code-block:: YAML
-
- strategy :
- class: TopkDropoutStrategy
- args:
- topk: 50
- n_drop: 5
-
-- `class`
- The strategy class, str type, should be a subclass of `qlib.contrib.strategy.strategy.BaseStrategy`. The default value is `TopkDropoutStrategy`.
-
-- `module_path`
- The module location, str type, absolute url is also supported, and absolute path is also supported, indicates the location of the policy class implementation.
-
-- `args`
- Parameters used for ``Trainer`` initialization.
-
- - `topk`
- The number of stocks in the portfolio
-
- - `n_drop`
- Number of stocks to be replaced in each trading date
-
-Custom Strategy
-^^^^^^^^^^^^^^^^^^^
-
-Qlib supports custom strategy, but it must be a subclass of the ``qlib.contrib.strategy.strategy.BaseStrategy``, the config for custom strategy may be as following:
-
-
-.. code-block:: YAML
-
- strategy :
- class: SomeStrategy
- module_path: /tmp/my_experment/custom_strategy.py
-
-The class `SomeStrategy` should be in the module `custom_strategy`, and ``Qlib`` could parse the `module_path` to load the class.
-
-To know more about ``Strategy``, please refer to `Strategy `_.
-
-Backtest Section
------------------
-
-Users can specify `backtest` through a config file, for example:
-
-.. code-block:: YAML
-
- backtest :
- normal_backtest_args:
- topk: 50
- benchmark: SH000905
- account: 500000
- deal_price: close
- min_cost: 5
- subscribe_fields:
- - $close
- - $change
- - $factor
-
-- `normal_backtest_args`
- Normal backtest parameters. All the parameters in this section will be passed to the ``qlib.contrib.evaluate.backtest`` function in the form of `**kwargs`.
-
- - `benchmark`
- Stock index symbol, str, or list type, the default value is `None`.
-
- .. note::
-
- * If `benchmark` is None, it will use the average change of the day of all stocks in 'pred' as the 'bench'.
-
- * If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'.
-
- * If `benchmark` is str, it will use the daily change as the 'bench'.
-
-
- - `account`
- Backtest initial cash, integer type. The `account` in `strategy` section is deprecated. It only works when `account` is not set in `backtest` section. It will be overridden by `account` in the `backtest` section. The default value is 1e9.
-
- - `deal_price`
- Order transaction price field, str type, the default value is close.
-
- - `min_cost`
- Min transaction cost, float type, the default value is 5.
-
- - `subscribe_fields`
- Subscribe quote fields, array type, the default value is [`deal_price`, $close, $change, $factor].
-
-
-Qlib Data Section
---------------------
-
-The `qlib_data` field describes the parameters of qlib initialization.
-
-.. code-block:: YAML
-
- qlib_data:
- # when testing, please modify the following parameters according to the specific environment
- provider_uri: "~/.qlib/qlib_data/cn_data"
- region: "cn"
-
-- `provider_uri`
- Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
-- `region`
- - If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
- - If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
-- `redis_host`
- Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
- The lock and cache mechanism relies on redis.
-- `redis_port`
- Type: int, optional parameter(default: 6379), port of `redis`
-
- .. note::
-
- The value of `region` should be aligned with the data stored in `provider_uri`. Currently, ``scripts/get_data.py`` only provides China stock market data. If users want to use the US stock market data, they should prepare their own US-stock data in `provider_uri` and switch to US-stock mode.
-
- .. note::
-
- If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache `_ for details.
-
-
-Please refer to `Initialization <../start/initialization.html>`_.
-
-Experiment Result
-===================
-
-Form of Experimental Result
-----------------------------
-The result of the experiment is also the result of the ``Intraday Trading(Backtest)``, please refer to `Intraday Trading: Model&Strategy Testing `_.
-
-
-Get Experiment Result
-----------------------------
-
-Base Class & Interface
-~~~~~~~~~~~~~~~~~~~~~~~
-
-Users can check the experiment results from file storage directly, or check the experiment results from the database, or get the experiment results through two interfaces of a base class `Fetcher` provided by ``Qlib``.
-
-The `Fetcher` provides the following interface
- - `get_experiments(self, exp_name=None):`
- The interface takes one parameters. The `exp_name` is the experiment name, the default is all experiments. Users can get the returned dictionary with a list of ids and test end date as follows.
-
- .. code-block:: JSON
-
- {
- "ex_a": [
- {
- "id": 1,
- "test_end_date": "2017-01-01"
- }
- ],
- "ex_b": [
- ...
- ]
- }
-
-
- - `get_experiment(exp_name, exp_id, fields=None)`
- The interface takes three parameters. The first parameter is the experiment name, the second parameter is the experiment id, and the third parameter is a list of fields. The default value of `fields` is None, which means all fields.
-
-
- .. note::
- Currently supported fields:
- ['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
-
- Users can get the returned dictionary as follows.
-
- .. code-block:: JSON
-
- {
- 'analysis': analysis_df,
- 'pred': pred_df,
- 'positions': positions_dic,
- 'report_normal': report_normal_df,
- }
-
-Implemented `Fetcher` s & Examples
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-``Qlib`` provides two implemented `Fetcher` s as follows.
-
-`FileFetcher`
-^^^^^^^^^^^^^^^
-
-The `FileFetcher` is a subclass of `Fetcher`, which could fetch files from `file_storage` observer. The following is an example:
-.. code-block:: python
-
- >>> from qlib.contrib.estimator.fetcher import FileFetcher
- >>> f = FileFetcher(experiments_dir=r'./')
- >>> print(f.get_experiments())
- {
- 'test_experiment': [
- {
- 'id': '1',
- 'config': ...
- },
- {
- 'id': '2',
- 'config': ...
- },
- {
- 'id': '3',
- 'config': ...
- }
- ]
- }
- >>> print(f.get_experiment('test_experiment', '1'))
- risk
- excess_return_without_cost mean 0.000605
- std 0.005481
- annualized_return 0.152373
- information_ratio 1.751319
- max_drawdown -0.059055
- excess_return_with_cost mean 0.000410
- std 0.005478
- annualized_return 0.103265
- information_ratio 1.187411
- max_drawdown -0.075024
-
-
-
-`MongoFetcher`
-^^^^^^^^^^^^^^^
-
-The `FileFetcher` is a subclass of `Fetcher`, which could fetch files from `mongo` observer. Users should initialize the fetcher with `mongo_url`. The following is an example:
-
-.. code-block:: python
-
- >>> from qlib.contrib.estimator.fetcher import MongoFetcher
- >>> f = MongoFetcher(mongo_url=..., db_name=...)
-
diff --git a/docs/component/model.rst b/docs/component/model.rst
index 0cd375a24..e4aa4ca91 100644
--- a/docs/component/model.rst
+++ b/docs/component/model.rst
@@ -1,4 +1,5 @@
.. _model:
+
============================================
Interday Model: Model Training & Prediction
============================================
@@ -6,164 +7,138 @@ Interday Model: Model Training & Prediction
Introduction
===================
-``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``Estimator``, please refer to `Estimator: Workflow Management `_.
+``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management `_.
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Interday Model`` can be used as an independent module also.
Base Class & Interface
======================
-``Qlib`` provides a base class `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ from which all models should inherit.
+``Qlib`` provides a base class `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ from which all models should inherit.
The base class provides the following interfaces:
- `__init__(**kwargs)`
- Initialization.
- - If users use ``Estimator`` to start an `experiment`, the parameter of `__init__` method shoule be consistent with the hyperparameters in the configuration file.
-- `fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)`
+- `fit(self, dataset, **kwargs)`
- Train model.
- Parameter:
- - `x_train`, pd.DataFrame type, train feature
- The following example explains the value of `x_train`:
+ - `dataset`, ``Qlib``'s ``DatasetH`` type. For more information about ``DatasetH``, users can refer to the related document: `Qlib Dataset <../component/data.html#dataset>`_.
+ The `dataset` is passed into the `model`'s method because there are some unique data preprocessing procedures for each, we want to give each model maximum flexibility to handle the data that is suitable for their own.
+ The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:
- .. code-block:: YAML
-
- KMID KLEN KMID2 KUP KUP2
- instrument datetime
- SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275
- 2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998
- 2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666
- 2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001
- 2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648
- ... ... ... ... ... ...
- SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052
- 2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824
- 2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000
- 2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433
- 2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216
+ .. code-block:: Python
-
- `x_train` is a pandas DataFrame, whose index is MultiIndex . Each column of `x_train` corresponds to a feature, and the column name is the feature name.
-
- .. note::
-
- The number and names of the columns are determined by the data handler, please refer to `Data Handler `_ and `Estimator Data Section `_.
-
- - `y_train`, pd.DataFrame type, train label
- The following example explains the value of `y_train`:
+ # get features and labels
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
- .. code-block:: YAML
-
- LABEL
- instrument datetime
- SH600004 2012-01-04 -0.798456
- 2012-01-05 -1.366716
- 2012-01-06 -0.491026
- 2012-01-09 0.296900
- 2012-01-10 0.501426
- ... ...
- SZ300273 2014-12-25 -0.465540
- 2014-12-26 0.233864
- 2014-12-29 0.471368
- 2014-12-30 0.411914
- 2014-12-31 1.342723
-
- `y_train` is a pandas DataFrame, whose index is MultiIndex . The `LABEL` column represents the value of train label.
-
- .. note::
-
- The number and names of the columns are determined by the ``Data Handler``, please refer to `Data Handler `_.
-
- - `x_valid`, pd.DataFrame type, validation feature
- The format of `x_valid` is same as `x_train`
-
-
- - `y_valid`, pd.DataFrame type, validation label
- The format of `y_valid` is same as `y_train`
-
- - `w_train`(Optional args, default is None), pd.DataFrame type, train weight
- `w_train` is a pandas DataFrame, whose shape and index is same as `x_train`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
-
- - `w_train`(Optional args, default is None), pd.DataFrame type, validation weight
- `w_train` is a pandas DataFrame, whose shape and index is the same as `x_valid`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
-
-- `predict(self, x_test, **kwargs)`
- - Predict test data 'x_test'
- - Parameter:
- - `x_test`, pd.DataFrame type, test features
- The form of `x_test` is same as `x_train` in 'fit' method.
- - Return:
- - `label`, np.ndarray type, test label
- The label of `x_test` that predicted by model.
-
-- `score(self, x_test, y_test, w_test=None, **kwargs)`
- - Evaluate model with test feature/label
- - Parameter:
- - `x_test`, pd.DataFrame type, test feature
- The format of `x_test` is same as `x_train` in `fit` method.
+ # get weights
+ try:
+ wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
+ w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
+ except KeyError as e:
+ w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
+ w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
- - `x_test`, pd.DataFrame type, test label
- The format of `y_test` is same as `y_train` in `fit` method.
+- `predict(self, dataset, **kwargs)`
+ - Predict test data.
+ - Parameter:
+ - `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
+ - Returns:
+ - Predic results with type: `pandas.Series`.
- - `w_test`, pd.DataFrame type, test weight
- The format of `w_test` is same as `w_train` in `fit` method.
- - Return: float type, evaluation score
+- `finetune(self, dataset, **kwargs)`
+ - Finetune the model.
+ - Parameter:
+ - `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
-For other interfaces such as `save`, `load`, `finetune`, please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
+
+For other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
Example
==================
-``Qlib`` provides ``LightGBM`` and ``DNN`` models as the baseline, the following steps show how to run`` LightGBM`` as an independent module.
+``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. The following steps show how to run`` LightGBM`` as an independent module.
- Initialize ``Qlib`` with `qlib.init` first, please refer to `Initialization <../start/initialization.html>`_.
- Run the following code to get the `prediction score` `pred_score`
.. code-block:: Python
- from qlib.contrib.estimator.handler import Alpha158
from qlib.contrib.model.gbdt import LGBModel
+ from qlib.contrib.data.handler import Alpha158
+ from qlib.utils import init_instance_by_config, flatten_dict
+ from qlib.workflow import R
+ from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
- DATA_HANDLER_CONFIG = {
- "dropna_label": True,
- "start_date": "2007-01-01",
- "end_date": "2020-08-01",
- "market": MARKET,
+ market = "csi300"
+ benchmark = "SH000300"
+
+ data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": market,
}
- TRAINER_CONFIG = {
- "train_start_date": "2007-01-01",
- "train_end_date": "2014-12-31",
- "validate_start_date": "2015-01-01",
- "validate_end_date": "2016-12-31",
- "test_start_date": "2017-01-01",
- "test_end_date": "2020-08-01",
+ task = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ "kwargs": {
+ "loss": "mse",
+ "colsample_bytree": 0.8879,
+ "learning_rate": 0.0421,
+ "subsample": 0.8789,
+ "lambda_l1": 205.6999,
+ "lambda_l2": 580.9768,
+ "max_depth": 8,
+ "num_leaves": 210,
+ "num_threads": 20,
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ },
}
+
+ # model initiaiton
+ model = init_instance_by_config(task["model"])
+ dataset = init_instance_by_config(task["dataset"])
- x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(
- **DATA_HANDLER_CONFIG
- ).get_split_data(**TRAINER_CONFIG)
+ # start exp
+ with R.start(experiment_name="workflow"):
+ # train
+ R.log_params(**flatten_dict(task))
+ model.fit(dataset)
+ # prediction
+ recorder = R.get_recorder()
+ sr = SignalRecord(model, dataset, recorder)
+ sr.generate()
- MODEL_CONFIG = {
- "loss": "mse",
- "colsample_bytree": 0.8879,
- "learning_rate": 0.0421,
- "subsample": 0.8789,
- "lambda_l1": 205.6999,
- "lambda_l2": 580.9768,
- "max_depth": 8,
- "num_leaves": 210,
- "num_threads": 20,
- }
- # use default model
- model = LGBModel(**MODEL_CONFIG)
- model.fit(x_train, y_train, x_validate, y_validate)
- _pred = model.predict(x_test)
- pred_score = pd.DataFrame(index=_pred.index)
- pred_score["score"] = _pred.iloc(axis=1)[0]
-
- .. note:: `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler `_.
+ .. note::
+
+ `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler `_.
+ `SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow `_.
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
@@ -175,4 +150,4 @@ Qlib supports custom models. If users are interested in customizing their own mo
API
===================
-Please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
+Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
diff --git a/docs/component/recorder.rst b/docs/component/recorder.rst
new file mode 100644
index 000000000..4304dcce5
--- /dev/null
+++ b/docs/component/recorder.rst
@@ -0,0 +1,97 @@
+.. _recorder:
+
+====================================
+Qlib Recorder: Experiment Management
+====================================
+.. currentmodule:: qlib
+
+Introduction
+===================
+``Qlib`` contains an experiment management system named ``QlibRecorder``, which is designed to help users handle experiment and analysis results in an efficient way.
+
+There are three components of the system:
+
+- `ExperimentManager`
+ a class that manages experiments.
+
+- `Experiment`
+ a class of experiment, and each instance of it is responsible for a single experiment.
+
+- `Recorder`
+ a class of recorder, and each instance of it is responsible for a single run.
+
+Here is a general view of the structure of the system:
+
+.. code-block::
+
+ ExperimentManager
+ - Experiment 1
+ - Recorder 1
+ - Recorder 2
+ - ...
+ - Experiment 2
+ - Recorder 1
+ - Recorder 2
+ - ...
+ - ...
+
+Currently, the components of this experiment management system are implemented using the machine learning platform: ``MLFlow`` (`link `_).
+
+
+Qlib Recorder
+===================
+``QlibRecorder`` provides a high level API for users to use the experiment management system. The interfaces are wrapped in the variable ``R`` in ``Qlib``, and users can directly use ``R`` to interact with the system. The following command shows how to import ``R`` in Python:
+
+.. code-block:: Python
+
+ from qlib.workflow import R
+
+``QlibRecorder`` includes several common API for managing `experiments` and `recorders` within a workflow. For more available APIs, please refer to the following section about `Experiment Manager`, `Experiment` and `Recorder`.
+
+Here are the available interfaces of ``QlibRecorder``:
+
+.. autoclass:: qlib.workflow.__init__.QlibRecorder
+ :members:
+
+Experiment Manager
+===================
+
+The ``ExpManager`` module in ``Qlib`` is responsible for managing different experiments. Most of the APIs of ``ExpManager`` are similar to ``QlibRecorder``, and the most important API will be the ``get_exp`` method. User can directly refer to the documents above for some detailed information about how to use the ``get_exp`` method.
+
+.. autoclass:: qlib.workflow.expm.ExpManager
+ :members: get_exp, list_experiments
+
+For other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.
+
+Experiment
+===================
+
+The ``Experiment`` class is solely responsible for a single experiment, and it will handle any operations that are related to an experiment. Basic methods such as `start`, `end` an experiment are included. Besides, methods related to `recorders` are also available: such methods include `get_recorder` and `list_recorders`.
+
+.. autoclass:: qlib.workflow.exp.Experiment
+ :members: get_recorder, list_recorders
+
+For other interfaces such as `search_records`, `delete_recorder`, please refer to `Experiment API <../reference/api.html#experiment>`_.
+
+Recorder
+===================
+
+The ``Recorder`` class is responsible for a single recorder. It will handle some detailed operations such as ``log_metrics``, ``log_params`` of a single run. It is designed to help user to easily track results and things being generated during a run.
+
+Here are some important APIs that are not included in the ``QlibRecorder``:
+
+.. autoclass:: qlib.workflow.recorder.Recorder
+ :members: list_artifacts, list_metrics, list_params, list_tags
+
+For other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.
+
+Record Template
+===================
+
+The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
+
+- ``SignalRecord``: This class generates the `preidction` results of the model.
+- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
+- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
+
+For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
\ No newline at end of file
diff --git a/docs/component/report.rst b/docs/component/report.rst
index 81ebbd1f3..7d8053c78 100644
--- a/docs/component/report.rst
+++ b/docs/component/report.rst
@@ -1,12 +1,13 @@
.. _report:
+
==========================================
-Aanalysis: Evaluation & Results Analysis
+Analysis: Evaluation & Results Analysis
==========================================
Introduction
===================
-``Aanalysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
+``Analysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
- analysis_position
- report_graph
diff --git a/docs/component/strategy.rst b/docs/component/strategy.rst
index c0ee687ce..0bdf453fe 100644
--- a/docs/component/strategy.rst
+++ b/docs/component/strategy.rst
@@ -1,4 +1,5 @@
.. _strategy:
+
========================================
Interday Strategy: Portfolio Management
========================================
diff --git a/docs/component/workflow.rst b/docs/component/workflow.rst
new file mode 100644
index 000000000..c44f1100f
--- /dev/null
+++ b/docs/component/workflow.rst
@@ -0,0 +1,280 @@
+.. _workflow:
+
+=================================
+Workflow: Workflow Management
+=================================
+.. currentmodule:: qlib
+
+Introduction
+===================
+
+The components in `Qlib Framework <../introduction/introduction.html#framework>`_ are designed in a loosely-coupled way. Users could build their own Quant research workflow with these components like `Example `_.
+
+
+Besides, ``Qlib`` provides more user-friendly interfaces named ``qrun`` to automatically run the whole workflow defined by configuration. A concrete execution of the whole workflow is called an `experiment`.
+With ``qrun``, user can easily run an `experiment`, which includes the following steps:
+
+- Data
+ - Loading
+ - Processing
+ - Slicing
+- Model
+ - Training and inference
+ - Saving & loading
+- Evaluation
+ - Forecast signal analysis
+ - Backtest
+
+For each `experiment`, ``Qlib`` has a complete system to tracking all the information as well as artifacts generated during training, inference and evaluation phase. For more information about how Qlib handles `experiment`, please refer to the related document: `Recorder: Experiment Management <../component/recorder.html>`_.
+
+Complete Example
+===================
+
+Before getting into details, here is a complete example of ``qrun``, which defines the workflow in typical Quant research.
+Below is a typical config file of ``qrun``.
+
+.. code-block:: YAML
+
+ provider_uri: "~/.qlib/qlib_data/cn_data"
+ region: cn
+ market: &market csi300
+ benchmark: &benchmark SH000300
+ data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+ task:
+ model:
+ class: LGBModel
+ module_path: qlib.contrib.model.gbdt
+ kwargs:
+ loss: mse
+ colsample_bytree: 0.8879
+ learning_rate: 0.0421
+ subsample: 0.8789
+ lambda_l1: 205.6999
+ lambda_l2: 580.9768
+ max_depth: 8
+ num_leaves: 210
+ num_threads: 20
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
+
+After saving the config into `configuration.yaml`, users could start the workflow and test their ideas with a single command below.
+
+.. code-block:: bash
+
+ qrun -c configuration.yaml
+
+.. note::
+
+ `qrun` will be placed in your $PATH directory when installing ``Qlib``.
+
+
+Configuration File
+===================
+
+Let's get into details of ``qrun`` in this section.
+
+Before using ``qrun``, users need to prepare a configuration file. The following content shows how to prepare each part of the configuration file.
+
+Qlib Data Section
+--------------------
+
+At first, the configuration file needs to contain several basic parameters about the data, which will be used for qlib initialization, data handling and backtest.
+
+.. code-block:: YAML
+
+ provider_uri: "~/.qlib/qlib_data/cn_data"
+ region: cn
+ market: &market csi300
+ benchmark: &benchmark SH000300
+
+The meaning of each field is as follows:
+
+- `provider_uri`
+ Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
+
+- `region`
+ - If `region` == "us", ``Qlib`` will be initialized in US-stock mode.
+ - If `region` == "cn", ``Qlib`` will be initialized in china-stock mode.
+
+ .. note::
+
+ The value of `region` should be aligned with the data stored in `provider_uri`.
+
+- `market`
+ Type: str. Index name, the default value is `csi500`.
+
+- `benchmark`
+ Type: str, list or pandas.Series. Stock index symbol, the default value is `SH000905`.
+
+ .. note::
+
+ * If `benchmark` is str, it will use the daily change as the 'bench'.
+
+ * If `benchmark` is list, it will use the daily average change of the stock pool in the list as the 'bench'.
+
+ * If `benchmark` is pandas.Series, whose `index` is trading date and the value T is the change from T-1 to T, it will be directly used as the 'bench'. An example is as following:
+
+ .. code-block:: python
+
+ print(D.features(D.instruments('csi500'), ['$close/Ref($close, 1)-1'])['$close/Ref($close, 1)-1'].head())
+ 2017-01-04 0.011693
+ 2017-01-05 0.000721
+ 2017-01-06 -0.004322
+ 2017-01-09 0.006874
+ 2017-01-10 -0.003350
+.. note::
+
+ The symbol `&` in `yaml` file stands for an anchor of a field, which is useful when another fields include this parameter as part of the value. Taking the configuration file above as an example, users can directly change the value of `market` and `benchmark` without traversing the entire configuration file.
+
+Model Section
+--------------------
+
+In the `task` field, the `model` section describes the parameters of the model to be used for training and inference. For more information about the base ``Model`` class, please refer to `Qlib Model <../component/model.html>`_.
+
+.. code-block:: YAML
+
+ model:
+ class: LGBModel
+ module_path: qlib.contrib.model.gbdt
+ kwargs:
+ loss: mse
+ colsample_bytree: 0.8879
+ learning_rate: 0.0421
+ subsample: 0.8789
+ lambda_l1: 205.6999
+ lambda_l2: 580.9768
+ max_depth: 8
+ num_leaves: 210
+ num_threads: 20
+
+The meaning of each field is as follows:
+
+- `class`
+ Type: str. The name for the model class.
+
+- `module_path`
+ Type: str. The path for the model in qlib.
+
+- `kwargs`
+ The keywords arguments for the model. Please refer to the specific model implementation for more information: `models `_.
+
+.. note::
+
+ ``Qlib`` provides a util named: ``init_instance_by_config`` to initialize any class inside ``Qlib`` with the configuration includes the fields: `class`, `module_path` and `kwargs`.
+
+Dataset Section
+--------------------
+
+The `dataset` field describes the parameters for the ``Dataset`` module in ``Qlib`` as well those for the module ``DataHandler``. For more information about the ``Dataset`` module, please refer to `Qlib Model <../component/data.html#dataset>`_.
+
+The keywords arguments configuration of the ``DataHandler`` is as follows:
+
+.. code-block:: YAML
+
+ data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+
+Users can refer to the document of `DataHandler <../component/data.html#datahandler>`_ for more information about the meaning of each field in the configuration.
+
+Here is the configuration for the ``Dataset`` module which will take care of data preprossing and slicing during the training and testing phase.
+
+.. code-block:: YAML
+
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+
+Record Section
+--------------------
+
+The `record` field is about the parameters the ``Record`` module in ``Qlib``. ``Record`` is responsible for generating certain analysis and evaluation results such as `prediction`, `information Coefficient (IC)` and `backtest`.
+
+The following script is the configuration of `backtest` and the `strategy` used in `backtest`:
+
+.. code-block:: YAML
+
+ port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+
+For more information about the meaning of each field in configuration of `strategy` and `backtest`, users can look up the documents: `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
+
+Here is the configuration details of different `Record Template` such as ``SignalRecord`` and ``PortAnaRecord``:
+
+.. code-block:: YAML
+
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
+
+For more information about the ``Record`` module in ``Qlib``, user can refer to the related document: `Record <../component/recorder.html#record-template>`_.
diff --git a/docs/conf.py b/docs/conf.py
index b91efb9a9..5359d08ed 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -124,7 +124,7 @@ html_theme_options = {
"logo_only": True,
"collapse_navigation": False,
"display_version": False,
- "navigation_depth": 3,
+ "navigation_depth": 4,
}
# Add any paths that contain custom static files (such as style sheets) here,
diff --git a/docs/index.rst b/docs/index.rst
index 5bcbbf19b..1e43cf99e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -35,12 +35,13 @@ Document Structure
:maxdepth: 3
:caption: COMPONENTS:
- Estimator: Workflow Management
+ Workflow: Workflow Management
Data Layer: Data Framework&Usage
Interday Model: Model Training & Prediction
Interday Strategy: Portfolio Management
Intraday Trading: Model&Strategy Testing
- Aanalysis: Evaluation & Results Analysis
+ Qlib Recorder: Experiment Management
+ Analysis: Evaluation & Results Analysis
.. toctree::
:maxdepth: 3
@@ -48,6 +49,7 @@ Document Structure
Building Formulaic Alphas
Online & Offline mode
+
.. toctree::
:maxdepth: 3
:caption: REFERENCE:
diff --git a/docs/introduction/introduction.rst b/docs/introduction/introduction.rst
index 3e4d11e28..06fac46fa 100644
--- a/docs/introduction/introduction.rst
+++ b/docs/introduction/introduction.rst
@@ -21,27 +21,27 @@ Framework
At the module level, Qlib is a platform that consists of above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
-====================== ==============================================================================
-Name Description
-====================== ==============================================================================
-`Data layer` `DataServer` focuses on providing high-performance infrastructure for users to
- manage and retrieve raw data. `DataEnhancement` will preprocess the data and
- provide the best dataset to be fed into the models.
-`Interday Model` `Interday model` focuses on producing prediction scores (aka. `alpha`). Models
- are trained by `Model Creator` and managed by `Model Manager`. Users could
- choose one or multiple models for prediction. Multiple models could be combined
- with `Ensemble` module.
-`Interday Strategy` `Portfolio Generator` will take prediction scores as input and output the
- orders based on the current position to achieve the target portfolio.
+======================== ==============================================================================
+Name Description
+======================== ==============================================================================
+`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
+ `DataServer` provides high-performance infrastructure for users to manage
+ and retrieve raw data. `Trainer` provides flexible interface to control
+ the training process of models which enable algorithms controlling the
+ training process.
-`Intraday Trading` `Order Executor` is responsible for executing orders output by
- `Interday Strategy` and returning the executed results.
+`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
+ `Information Extractor` extracts data for models. `Forecast Model` focuses
+ on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
+ modules. With these signals `Portfolio Generator` will generate the target
+ portfolio and produce orders to be executed by `Order Executor`.
-`Analysis` Users could get a detailed analysis report of forecasting signals and portfolios
- in this part.
-====================== ==============================================================================
+`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
+ system. `Analyser` module will provide users detailed analysis reports of
+ forecasting signals, portfolios and execution results
+======================== ==============================================================================
- The modules with hand-drawn style are under development and will be released in the future.
- The modules with dashed borders are highly user-customizable and extendible.
diff --git a/docs/introduction/quick.rst b/docs/introduction/quick.rst
index 9fff8cb3f..f228ce2af 100644
--- a/docs/introduction/quick.rst
+++ b/docs/introduction/quick.rst
@@ -49,18 +49,19 @@ To kown more about `prepare data`, please refer to `Data Preparation <../compone
Auto Quant Research Workflow
====================================
-``Qlib`` provides a tool named ``Estimator`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
+``Qlib`` provides a tool named ``qrun`` to run the whole workflow automatically (including building dataset, training models, backtest and evaluation). Users can start an auto quant research workflow and have a graphical reports analysis according to the following steps:
- Quant Research Workflow:
- - Run ``Estimator`` with `estimator_config.yaml` as following.
+ - Run ``qrun`` with a config file of the LightGBM model `workflow_config_lightgbm.yaml` as following.
+
.. code-block::
cd examples # Avoid running program under the directory contains `qlib`
- estimator -c estimator/estimator_config.yaml
+ qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
- - Estimator result
- The result of ``Estimator`` is as follows, which is also the result of ``Intraday Trading``. Please refer to `Intraday Trading <../component/backtest.html>`_. for more details about the result.
+ - Workflow result
+ The result of ``qrun`` is as follows, which is also the typical result of ``Forecast model(alpha)``. Please refer to `Intraday Trading <../component/backtest.html>`_. for more details about the result.
.. code-block:: python
@@ -77,17 +78,17 @@ Auto Quant Research Workflow
max_drawdown -0.075024
- To know more about `Estimator`, please refer to `Estimator: Workflow Management <../component/estimator.html>`_.
+ To know more about `workflow` and `qrun`, please refer to `Workflow: Workflow Management <../component/workflow.html>`_.
- Graphical Reports Analysis:
- - Run ``examples/estimator/analyze_from_estimator.ipynb`` with jupyter notebook
- Users can have portfolio analysis or prediction score (model prediction) analysis by run ``examples/estimator/analyze_from_estimator.ipynb``.
+ - Run ``examples/workflow_by_code.ipynb`` with jupyter notebook
+ Users can have portfolio analysis or prediction score (model prediction) analysis by run ``examples/workflow_by_code.ipynb``.
- Graphical Reports
- Users can get graphical reports about the analysis, please refer to `Aanalysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
+ Users can get graphical reports about the analysis, please refer to `Analysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
Custom Model Integration
===============================================
-``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. If users are interested in the custom model, please refer to `Custom Model Integration <../start/integration.html>`_.
+``Qlib`` provides a batch of models (such as ``lightGBM`` and ``MLP`` models) as examples of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``. If users are interested in the custom model, please refer to `Custom Model Integration <../start/integration.html>`_.
diff --git a/docs/reference/api.rst b/docs/reference/api.rst
index ea1a545e2..f21a9f518 100644
--- a/docs/reference/api.rst
+++ b/docs/reference/api.rst
@@ -23,16 +23,13 @@ Filter
.. automodule:: qlib.data.filter
:members:
-Feature
---------------------
-
Class
-~~~~~~~~~~~~~~~~~~~~
+--------------------
.. automodule:: qlib.data.base
:members:
Operator
-~~~~~~~~~~~~~~~~~~~~
+--------------------
.. automodule:: qlib.data.ops
:members:
@@ -56,19 +53,36 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:
+Dataset
+---------------
+
+Dataset Class
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.__init__
+ :members:
+
+Data Loader
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.loader
+ :members:
+
+Data Handler
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.handler
+ :members:
+
+Processor
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.processor
+ :members:
+
Contrib
====================
-
-Data Handler
----------------
-.. automodule:: qlib.contrib.estimator.handler
- :members:
-
Model
--------------------
-.. automodule:: qlib.contrib.model.base
+.. automodule:: qlib.model.base
:members:
Strategy
@@ -116,3 +130,26 @@ Report
:members:
+Workflow
+====================
+
+
+Experiment Manager
+--------------------
+.. autoclass:: qlib.workflow.expm.ExpManager
+ :members:
+
+Experiment
+--------------------
+.. autoclass:: qlib.workflow.exp.Experiment
+ :members:
+
+Recorder
+--------------------
+.. autoclass:: qlib.workflow.recorder.Recorder
+ :members:
+
+Record Template
+--------------------
+.. automodule:: qlib.workflow.record_temp
+ :members:
\ No newline at end of file
diff --git a/docs/start/getdata.rst b/docs/start/getdata.rst
index b352082cb..8e1695c14 100644
--- a/docs/start/getdata.rst
+++ b/docs/start/getdata.rst
@@ -1,4 +1,5 @@
.. _getdata:
+
=============================
Data Retrieval
=============================
diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst
index 992307515..05a329df7 100644
--- a/docs/start/initialization.rst
+++ b/docs/start/initialization.rst
@@ -1,4 +1,5 @@
.. _initialization:
+
====================
Qlib Initialization
====================
@@ -11,14 +12,16 @@ Initialization
Please follow the steps below to initialize ``Qlib``.
-- Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance `_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>` for more information about customized dataset.
+Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance `_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized dataset.
+
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
- Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
+
+Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
-- Initialize Qlib before calling other APIs: run following code in python.
+Initialize Qlib before calling other APIs: run following code in python.
.. code-block:: Python
@@ -58,3 +61,16 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
.. note::
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
+- `exp_manager`
+ Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
+ .. code-block:: Python
+
+ # For example, if you want to set your tracking_uri to a , you can initialize qlib below
+ qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager= {
+ "class": "MLflowExpManager",
+ "module_path": "qlib.workflow.expm",
+ "kwargs": {
+ "uri": "python_execution_path/mlruns",
+ "default_exp_name": "Experiment",
+ }
+ })
diff --git a/docs/start/installation.rst b/docs/start/installation.rst
index 2ac3dda77..af0b37372 100644
--- a/docs/start/installation.rst
+++ b/docs/start/installation.rst
@@ -1,4 +1,5 @@
.. _installation:
+
====================
Installation
====================
diff --git a/docs/start/integration.rst b/docs/start/integration.rst
index 2732f61df..e36805c01 100644
--- a/docs/start/integration.rst
+++ b/docs/start/integration.rst
@@ -5,17 +5,17 @@ Custom Model Integration
Introduction
===================
-``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``.
+``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``MLP``, ``LSTM``, etc.. These models are examples of ``Interday Model``. In addition to the default models ``Qlib`` provide, users can integrate their own custom models into ``Qlib``.
Users can integrate their own custom models according to the following steps.
-- Define a custom model class, which should be a subclass of the `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_.
+- Define a custom model class, which should be a subclass of the `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_.
- Write a configuration file that describes the path and parameters of the custom model.
- Test the custom model.
Custom Model Class
===========================
-The Custom models need to inherit `qlib.contrib.model.base.Model <../reference/api.html#module-qlib.contrib.model.base>`_ and override the methods in it.
+The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ and override the methods in it.
- Override the `__init__` method
- ``Qlib`` passes the initialized parameters to the \_\_init\_\_ method.
@@ -32,79 +32,77 @@ The Custom models need to inherit `qlib.contrib.model.base.Model <../reference/a
- Override the `fit` method
- ``Qlib`` calls the fit method to train the model
- - The parameters must include training feature `x_train`, training label `y_train`, test feature `x_valid`, test label `y_valid` at least.
- - The parameters could include some optional parameters with default values, such as train weight `w_train`, test weight `w_valid` and `num_boost_round = 1000`.
+ - The parameters must include training feature `dataset`.
+ - The parameters could include some optional parameters with default values, such as `num_boost_round = 1000` for `GBDT`.
- Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.
.. code-block:: Python
- def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame,
- w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs):
+ def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
+
+ # prepare dataset for lgb training and evaluation
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
- y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
+ y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
- raise ValueError('LightGBM doesn\'t support multi-label training')
+ raise ValueError("LightGBM doesn't support multi-label training")
- w_train_weight = None if w_train is None else w_train.values
- w_valid_weight = None if w_valid is None else w_valid.values
+ dtrain = lgb.Dataset(x_train.values, label=y_train)
+ dvalid = lgb.Dataset(x_valid.values, label=y_valid)
- dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
- dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
- self._model = lgb.train(
- self._params,
- dtrain,
+ # fit the model
+ self.model = lgb.train(
+ self.params,
+ dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
- valid_names=['train', 'valid'],
+ valid_names=["train", "valid"],
+ early_stopping_rounds=early_stopping_rounds,
+ verbose_eval=verbose_eval,
+ evals_result=evals_result,
**kwargs
)
- Override the `predict` method
- - The parameters include the test features.
+ - The parameters must include training feature `dataset`, which will be userd to get the test dataset.
- Return the `prediction score`.
- - Please refer to `Model API <../reference/api.html#module-qlib.contrib.model.base>`_ for the parameter types of the fit method.
- - Code Example: In the following example, users need to use dnn to predict the label(such as `preds`) of test data `x_test` and return it.
+ - Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
+ - Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
.. code-block:: Python
- def predict(self, x_test:pd.DataFrame, **kwargs)-> numpy.ndarray:
- if self._model is None:
- raise ValueError('model is not fitted yet!')
- return self._model.predict(x_test.values)
+ def predict(self, dataset: DatasetH, **kwargs)-> pandas.Series:
+ if self.model is None:
+ raise ValueError("model is not fitted yet!")
+ x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
+ return pd.Series(self.model.predict(x_test.values), index=x_test.index)
-- Override the `save` method & `load` method
- - The `save` method parameter includes the a `filename` that represents an absolute path, user need to save model into the path.
- - The `load` method parameter includes the a `buffer` read from the `filename` passed in the `save` method, users need to load model from the `buffer`.
- - Code Example:
+- Override the `finetune` method
+ - The parameters must include training feature `dataset`.
+ - Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
.. code-block:: Python
- def save(self, filename):
- if self._model is None:
- raise ValueError('model is not fitted yet!')
- self._model.save_model(filename)
-
- def load(self, buffer):
- self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')})
-
-.. Without tuner, this part will not be used
-.. - Override the `score` method(This step is optional)
-.. - The parameters include the test features and test labels.
-.. - Return the evaluation score of the model. It's recommended to adopt the loss between labels and `prediction score`.
-.. - Code Example: In the following example, users need to calculate the weighted loss with test data `x_test`, test label `y_test` and the weight `w_test`.
-.. .. code-block:: Python
-..
-.. def score(self, x_test:pd.Dataframe, y_test:pd.Dataframe, w_test:pd.DataFrame = None) -> float:
-.. # Remove rows from x, y and w, which contain Nan in any columns in y_test.
-.. x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
-.. preds = self.predict(x_test)
-.. w_test_weight = None if w_test is None else w_test.values
-.. scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score
-.. return scorer(y_test.values, preds, sample_weight=w_test_weight)
+ def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
+ # Based on existing model and finetune by train more rounds
+ dtrain, _ = self._prepare_data(dataset)
+ self.model = lgb.train(
+ self.params,
+ dtrain,
+ num_boost_round=num_boost_round,
+ init_model=self.model,
+ valid_sets=[dtrain],
+ valid_names=["train"],
+ verbose_eval=verbose_eval,
+ )
Configuration File
=======================
-The configuration file is described in detail in the `estimator <../component/estimator.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file.
+The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file. The configuration describes which models to use and how we can initialize it.
- Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
@@ -124,23 +122,23 @@ The configuration file is described in detail in the `estimator <../component/es
num_leaves: 210
num_threads: 20
-Users could find configuration file of the baseline of the ``Model`` in ``qlib/examples/estimator/estimator_config.yaml`` and ``qlib/examples/estimator/estimator_config_dnn.yaml``
+Users could find configuration file of the baselines of the ``Model`` in ``examples/benchmarks``. All the configurations of different models are listed under the corresponding model folder.
Model Testing
=====================
-Assuming that the configuration file is ``examples/estimator/estimator_config.yaml``, users can run the following command to test the custom model:
+Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following command to test the custom model:
.. code-block:: bash
cd examples # Avoid running program under the directory contains `qlib`
- estimator -c estimator/estimator_config.yaml
+ qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
-.. note:: ``estimator`` is a built-in command of ``Qlib``.
+.. note:: ``qrun`` is a built-in command of ``Qlib``.
-Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/train_backtest_analyze.ipynb``.
+Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/workflow_by_code.ipynb``.
Reference
=====================
-To know more about ``Interday Model``, please refer to `Interday Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.contrib.model.base>`_.
+To know more about ``Interday Model``, please refer to `Interday Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.model.base>`_.
diff --git a/examples/benchmarks/ALSTM/README.md b/examples/benchmarks/ALSTM/README.md
new file mode 100644
index 000000000..1b749bd80
--- /dev/null
+++ b/examples/benchmarks/ALSTM/README.md
@@ -0,0 +1,8 @@
+# ALSTM
+
+- ALSTM contains a temporal attentive aggregation layer based on normal LSTM.
+
+- Paper: A dual-stage attention-based recurrent neural network for time series prediction.
+
+ [https://www.ijcai.org/Proceedings/2017/0366.pdf](https://www.ijcai.org/Proceedings/2017/0366.pdf)
+
diff --git a/examples/benchmarks/ALSTM/requirements.txt b/examples/benchmarks/ALSTM/requirements.txt
new file mode 100644
index 000000000..1fc2779c0
--- /dev/null
+++ b/examples/benchmarks/ALSTM/requirements.txt
@@ -0,0 +1,4 @@
+numpy==1.17.4
+pandas==1.1.2
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm.yaml
new file mode 100644
index 000000000..dd57761f3
--- /dev/null
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm.yaml
@@ -0,0 +1,83 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: ALSTM
+ module_path: qlib.contrib.model.pytorch_alstm
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ num_layers: 2
+ dropout: 0.0
+ n_epochs: 200
+ lr: 1e-3
+ early_stop: 20
+ batch_size: 800
+ metric: loss
+ loss: mse
+ seed: 0
+ GPU: 0
+ rnn_type: GRU
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: ALPHA360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/CatBoost/README.md b/examples/benchmarks/CatBoost/README.md
new file mode 100644
index 000000000..5e4f3966f
--- /dev/null
+++ b/examples/benchmarks/CatBoost/README.md
@@ -0,0 +1,3 @@
+# CatBoost
+* Code: [https://github.com/catboost/catboost](https://github.com/catboost/catboost)
+* Paper: CatBoost: unbiased boosting with categorical features. [https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf](https://proceedings.neurips.cc/paper/2018/file/14491b756b3a51daac41c24863285549-Paper.pdf).
\ No newline at end of file
diff --git a/examples/benchmarks/CatBoost/requirements.txt b/examples/benchmarks/CatBoost/requirements.txt
new file mode 100644
index 000000000..507a65944
--- /dev/null
+++ b/examples/benchmarks/CatBoost/requirements.txt
@@ -0,0 +1,3 @@
+pandas==1.1.2
+numpy==1.17.4
+catboost==0.24.3
diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml
new file mode 100644
index 000000000..9c15dc25b
--- /dev/null
+++ b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml
@@ -0,0 +1,64 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: CatBoostModel
+ module_path: qlib.contrib.model.catboost_model
+ kwargs:
+ loss: RMSE
+ learning_rate: 0.0421
+ subsample: 0.8789
+ max_depth: 6
+ num_leaves: 100
+ thread_count: 20
+ grow_policy: Lossguide
+ bootstrap_type: Poisson
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/GATs/README.md b/examples/benchmarks/GATs/README.md
new file mode 100644
index 000000000..f432b6c5b
--- /dev/null
+++ b/examples/benchmarks/GATs/README.md
@@ -0,0 +1,5 @@
+# GATs
+* Graph Attention Networks(GATs) leverage masked self-attentional layers on graph-structured data. The nodes in stacked layers have different weights and they are able to attend over their
+neighborhoods’ features, without requiring any kind of costly matrix operation (such as inversion) or depending on knowing the graph structure upfront.
+* This code used in Qlib is implemented with PyTorch by ourselves.
+* Paper: Graph Attention Networks https://arxiv.org/pdf/1710.10903.pdf
\ No newline at end of file
diff --git a/examples/benchmarks/GATs/requirements.txt b/examples/benchmarks/GATs/requirements.txt
new file mode 100644
index 000000000..16de0a438
--- /dev/null
+++ b/examples/benchmarks/GATs/requirements.txt
@@ -0,0 +1,4 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/GATs/workflow_config_gats.yaml b/examples/benchmarks/GATs/workflow_config_gats.yaml
new file mode 100644
index 000000000..c38b4b312
--- /dev/null
+++ b/examples/benchmarks/GATs/workflow_config_gats.yaml
@@ -0,0 +1,77 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: GATs
+ module_path: qlib.contrib.model.pytorch_gats
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ num_layers: 2
+ dropout: 0.7
+ n_epochs: 200
+ lr: 1e-4
+ early_stop: 20
+ metric: loss
+ loss: mse
+ base_model: LSTM
+ seed: 0
+ GPU: 0
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: ALPHA360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/GRU/model_gru_csi300.pkl b/examples/benchmarks/GRU/model_gru_csi300.pkl
new file mode 100644
index 000000000..46347ce8c
Binary files /dev/null and b/examples/benchmarks/GRU/model_gru_csi300.pkl differ
diff --git a/examples/benchmarks/GRU/requirements.txt b/examples/benchmarks/GRU/requirements.txt
new file mode 100644
index 000000000..1fc2779c0
--- /dev/null
+++ b/examples/benchmarks/GRU/requirements.txt
@@ -0,0 +1,4 @@
+numpy==1.17.4
+pandas==1.1.2
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/GRU/workflow_config_gru.yaml b/examples/benchmarks/GRU/workflow_config_gru.yaml
new file mode 100644
index 000000000..bdfcd4e55
--- /dev/null
+++ b/examples/benchmarks/GRU/workflow_config_gru.yaml
@@ -0,0 +1,82 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: GRU
+ module_path: qlib.contrib.model.pytorch_gru
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ num_layers: 2
+ dropout: 0.0
+ n_epochs: 200
+ lr: 1e-3
+ early_stop: 20
+ batch_size: 800
+ metric: loss
+ loss: mse
+ seed: 0
+ GPU: 0
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: ALPHA360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/LSTM/model_lstm_csi300.pkl b/examples/benchmarks/LSTM/model_lstm_csi300.pkl
new file mode 100644
index 000000000..84d6419da
Binary files /dev/null and b/examples/benchmarks/LSTM/model_lstm_csi300.pkl differ
diff --git a/examples/benchmarks/LSTM/requirements.txt b/examples/benchmarks/LSTM/requirements.txt
new file mode 100644
index 000000000..1fc2779c0
--- /dev/null
+++ b/examples/benchmarks/LSTM/requirements.txt
@@ -0,0 +1,4 @@
+numpy==1.17.4
+pandas==1.1.2
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/LSTM/workflow_config_lstm.yaml b/examples/benchmarks/LSTM/workflow_config_lstm.yaml
new file mode 100644
index 000000000..6512a0df3
--- /dev/null
+++ b/examples/benchmarks/LSTM/workflow_config_lstm.yaml
@@ -0,0 +1,82 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: LSTM
+ module_path: qlib.contrib.model.pytorch_lstm
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ num_layers: 2
+ dropout: 0.0
+ n_epochs: 200
+ lr: 1e-3
+ early_stop: 20
+ batch_size: 800
+ metric: loss
+ loss: mse
+ seed: 0
+ GPU: 0
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: ALPHA360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/LightGBM/README.md b/examples/benchmarks/LightGBM/README.md
new file mode 100644
index 000000000..13f408d5f
--- /dev/null
+++ b/examples/benchmarks/LightGBM/README.md
@@ -0,0 +1,4 @@
+# LightGBM
+* Code: [https://github.com/microsoft/LightGBM](https://github.com/microsoft/LightGBM)
+* Paper: LightGBM: A Highly Efficient Gradient Boosting
+Decision Tree. [https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf](https://proceedings.neurips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf).
\ No newline at end of file
diff --git a/examples/benchmarks/LightGBM/requirements.txt b/examples/benchmarks/LightGBM/requirements.txt
new file mode 100644
index 000000000..507d2d453
--- /dev/null
+++ b/examples/benchmarks/LightGBM/requirements.txt
@@ -0,0 +1,3 @@
+pandas==1.1.2
+numpy==1.17.4
+lightgbm==3.1.0
diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
new file mode 100644
index 000000000..790fc3ae5
--- /dev/null
+++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml
@@ -0,0 +1,65 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: LGBModel
+ module_path: qlib.contrib.model.gbdt
+ kwargs:
+ loss: mse
+ colsample_bytree: 0.8879
+ learning_rate: 0.0421
+ subsample: 0.8789
+ lambda_l1: 205.6999
+ lambda_l2: 580.9768
+ max_depth: 8
+ num_leaves: 210
+ num_threads: 20
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/Linear/requirements.txt b/examples/benchmarks/Linear/requirements.txt
new file mode 100644
index 000000000..6a53211f9
--- /dev/null
+++ b/examples/benchmarks/Linear/requirements.txt
@@ -0,0 +1,3 @@
+numpy>=1.17.4
+pandas>=1.0.1
+scikit-learn>=0.23.1
diff --git a/examples/benchmarks/Linear/workflow_config_linear.yaml b/examples/benchmarks/Linear/workflow_config_linear.yaml
new file mode 100644
index 000000000..70d3eaf68
--- /dev/null
+++ b/examples/benchmarks/Linear/workflow_config_linear.yaml
@@ -0,0 +1,71 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: LinearModel
+ module_path: qlib.contrib.model.linear
+ kwargs:
+ estimator: ols
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: True
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/MLP/requirements.txt b/examples/benchmarks/MLP/requirements.txt
new file mode 100644
index 000000000..16de0a438
--- /dev/null
+++ b/examples/benchmarks/MLP/requirements.txt
@@ -0,0 +1,4 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/MLP/workflow_config_mlp.yaml b/examples/benchmarks/MLP/workflow_config_mlp.yaml
new file mode 100644
index 000000000..e01c4eb3a
--- /dev/null
+++ b/examples/benchmarks/MLP/workflow_config_mlp.yaml
@@ -0,0 +1,93 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors: [
+ {
+ "class" : "DropCol",
+ "kwargs":{"col_list": ["VWAP0"]}
+ },
+ {
+ "class" : "CSZFillna",
+ "kwargs":{"fields_group": "feature"}
+ }
+ ]
+ learn_processors: [
+ {
+ "class" : "DropCol",
+ "kwargs":{"col_list": ["VWAP0"]}
+ },
+ {
+ "class" : "DropnaProcessor",
+ "kwargs":{"fields_group": "feature"}
+ },
+ "DropnaLabel",
+ {
+ "class": "CSZScoreNorm",
+ "kwargs": {"fields_group": "label"}
+ }
+ ]
+ process_type: "independent"
+
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: DNNModelPytorch
+ module_path: qlib.contrib.model.pytorch_nn
+ kwargs:
+ loss: mse
+ input_dim: 157
+ output_dim: 1
+ lr: 0.002
+ lr_decay: 0.96
+ lr_decay_steps: 100
+ optimizer: adam
+ max_steps: 8000
+ batch_size: 4096
+ GPU: 0
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/SFM/README.md b/examples/benchmarks/SFM/README.md
new file mode 100644
index 000000000..5f74c15d2
--- /dev/null
+++ b/examples/benchmarks/SFM/README.md
@@ -0,0 +1,3 @@
+# State-Frequency-Memory
+- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
+- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. [https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.](https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.)
\ No newline at end of file
diff --git a/examples/benchmarks/SFM/requirements.txt b/examples/benchmarks/SFM/requirements.txt
new file mode 100644
index 000000000..6a3d13097
--- /dev/null
+++ b/examples/benchmarks/SFM/requirements.txt
@@ -0,0 +1,4 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
\ No newline at end of file
diff --git a/examples/benchmarks/SFM/workflow_config_sfm.yaml b/examples/benchmarks/SFM/workflow_config_sfm.yaml
new file mode 100644
index 000000000..3fa3f932c
--- /dev/null
+++ b/examples/benchmarks/SFM/workflow_config_sfm.yaml
@@ -0,0 +1,85 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: SFM
+ module_path: qlib.contrib.model.pytorch_sfm
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ output_dim: 32
+ freq_dim: 25
+ dropout_W: 0.5
+ dropout_U: 0.5
+ n_epochs: 20
+ lr: 1e-3
+ batch_size: 1600
+ early_stop: 20
+ eval_steps: 5
+ loss: mse
+ optimizer: adam
+ GPU: 1
+ seed: 710
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: ALPHA360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/TFT/README.md b/examples/benchmarks/TFT/README.md
new file mode 100644
index 000000000..5a6a9f153
--- /dev/null
+++ b/examples/benchmarks/TFT/README.md
@@ -0,0 +1,14 @@
+# Temporal Fusion Transformers Benchmark
+## Source
+**Reference**: Lim, Bryan, et al. "Temporal fusion transformers for interpretable multi-horizon time series forecasting." arXiv preprint arXiv:1912.09363 (2019).
+
+**GitHub**: https://github.com/google-research/google-research/tree/master/tft
+
+## Run the Workflow
+Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
+
+### Notes
+1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
+2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
+3. The model must run in GPU, or an error will be raised.
+4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
diff --git a/examples/benchmarks/TFT/data_formatters/__init__.py b/examples/benchmarks/TFT/data_formatters/__init__.py
new file mode 100644
index 000000000..87ec3284f
--- /dev/null
+++ b/examples/benchmarks/TFT/data_formatters/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/examples/benchmarks/TFT/data_formatters/base.py b/examples/benchmarks/TFT/data_formatters/base.py
new file mode 100644
index 000000000..c68a192ba
--- /dev/null
+++ b/examples/benchmarks/TFT/data_formatters/base.py
@@ -0,0 +1,223 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Default data formatting functions for experiments.
+
+For new datasets, inherit form GenericDataFormatter and implement
+all abstract functions.
+
+These dataset-specific methods:
+1) Define the column and input types for tabular dataframes used by model
+2) Perform the necessary input feature engineering & normalisation steps
+3) Reverts the normalisation for predictions
+4) Are responsible for train, validation and test splits
+
+
+"""
+
+import abc
+import enum
+
+
+# Type defintions
+class DataTypes(enum.IntEnum):
+ """Defines numerical types of each column."""
+
+ REAL_VALUED = 0
+ CATEGORICAL = 1
+ DATE = 2
+
+
+class InputTypes(enum.IntEnum):
+ """Defines input types of each column."""
+
+ TARGET = 0
+ OBSERVED_INPUT = 1
+ KNOWN_INPUT = 2
+ STATIC_INPUT = 3
+ ID = 4 # Single column used as an entity identifier
+ TIME = 5 # Single column exclusively used as a time index
+
+
+class GenericDataFormatter(abc.ABC):
+ """Abstract base class for all data formatters.
+
+ User can implement the abstract methods below to perform dataset-specific
+ manipulations.
+
+ """
+
+ @abc.abstractmethod
+ def set_scalers(self, df):
+ """Calibrates scalers using the data supplied."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def transform_inputs(self, df):
+ """Performs feature transformation."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def format_predictions(self, df):
+ """Reverts any normalisation to give predictions in original scale."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def split_data(self, df):
+ """Performs the default train, validation and test splits."""
+ raise NotImplementedError()
+
+ @property
+ @abc.abstractmethod
+ def _column_definition(self):
+ """Defines order, input type and data type of each column."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_fixed_params(self):
+ """Defines the fixed parameters used by the model for training.
+
+ Requires the following keys:
+ 'total_time_steps': Defines the total number of time steps used by TFT
+ 'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
+ 'num_epochs': Maximum number of epochs for training
+ 'early_stopping_patience': Early stopping param for keras
+ 'multiprocessing_workers': # of cpus for data processing
+
+
+ Returns:
+ A dictionary of fixed parameters, e.g.:
+
+ fixed_params = {
+ 'total_time_steps': 252 + 5,
+ 'num_encoder_steps': 252,
+ 'num_epochs': 100,
+ 'early_stopping_patience': 5,
+ 'multiprocessing_workers': 5,
+ }
+ """
+ raise NotImplementedError
+
+ # Shared functions across data-formatters
+ @property
+ def num_classes_per_cat_input(self):
+ """Returns number of categories per relevant input.
+
+ This is seqeuently required for keras embedding layers.
+ """
+ return self._num_classes_per_cat_input
+
+ def get_num_samples_for_calibration(self):
+ """Gets the default number of training and validation samples.
+
+ Use to sub-sample the data for network calibration and a value of -1 uses
+ all available samples.
+
+ Returns:
+ Tuple of (training samples, validation samples)
+ """
+ return -1, -1
+
+ def get_column_definition(self):
+ """"Returns formatted column definition in order expected by the TFT."""
+
+ column_definition = self._column_definition
+
+ # Sanity checks first.
+ # Ensure only one ID and time column exist
+ def _check_single_column(input_type):
+
+ length = len([tup for tup in column_definition if tup[2] == input_type])
+
+ if length != 1:
+ raise ValueError("Illegal number of inputs ({}) of type {}".format(length, input_type))
+
+ _check_single_column(InputTypes.ID)
+ _check_single_column(InputTypes.TIME)
+
+ identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]
+ time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
+ real_inputs = [
+ tup
+ for tup in column_definition
+ if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME}
+ ]
+ categorical_inputs = [
+ tup
+ for tup in column_definition
+ if tup[1] == DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME}
+ ]
+
+ return identifier + time + real_inputs + categorical_inputs
+
+ def _get_input_columns(self):
+ """Returns names of all input columns."""
+ return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ def _get_tft_input_indices(self):
+ """Returns the relevant indexes and input sizes required by TFT."""
+
+ # Functions
+ def _extract_tuples_from_data_type(data_type, defn):
+ return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ def _get_locations(input_types, defn):
+ return [i for i, tup in enumerate(defn) if tup[2] in input_types]
+
+ # Start extraction
+ column_definition = [
+ tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}
+ ]
+
+ categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition)
+ real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition)
+
+ locations = {
+ "input_size": len(self._get_input_columns()),
+ "output_size": len(_get_locations({InputTypes.TARGET}, column_definition)),
+ "category_counts": self.num_classes_per_cat_input,
+ "input_obs_loc": _get_locations({InputTypes.TARGET}, column_definition),
+ "static_input_loc": _get_locations({InputTypes.STATIC_INPUT}, column_definition),
+ "known_regular_inputs": _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs),
+ "known_categorical_inputs": _get_locations(
+ {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, categorical_inputs
+ ),
+ }
+
+ return locations
+
+ def get_experiment_params(self):
+ """Returns fixed model parameters for experiments."""
+
+ required_keys = [
+ "total_time_steps",
+ "num_encoder_steps",
+ "num_epochs",
+ "early_stopping_patience",
+ "multiprocessing_workers",
+ ]
+
+ fixed_params = self.get_fixed_params()
+
+ for k in required_keys:
+ if k not in fixed_params:
+ raise ValueError("Field {}".format(k) + " missing from fixed parameter definitions!")
+
+ fixed_params["column_definition"] = self.get_column_definition()
+
+ fixed_params.update(self._get_tft_input_indices())
+
+ return fixed_params
diff --git a/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py
new file mode 100644
index 000000000..44a9284f7
--- /dev/null
+++ b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py
@@ -0,0 +1,219 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Custom formatting functions for Alpha158 dataset.
+
+Defines dataset specific column definitions and data transformations.
+"""
+
+import data_formatters.base
+import libs.utils as utils
+import sklearn.preprocessing
+
+GenericDataFormatter = data_formatters.base.GenericDataFormatter
+DataTypes = data_formatters.base.DataTypes
+InputTypes = data_formatters.base.InputTypes
+
+
+class Alpha158Formatter(GenericDataFormatter):
+ """Defines and formats data for the Alpha158 dataset.
+
+ Attributes:
+ column_definition: Defines input and data type of column used in the
+ experiment.
+ identifiers: Entity identifiers used in experiments.
+ """
+
+ _column_definition = [
+ ("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
+ ("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
+ ("date", DataTypes.DATE, InputTypes.TIME),
+ ("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
+ ("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
+ # Selected 10 features
+ ("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
+ ]
+
+ def __init__(self):
+ """Initialises formatter."""
+
+ self.identifiers = None
+ self._real_scalers = None
+ self._cat_scalers = None
+ self._target_scaler = None
+ self._num_classes_per_cat_input = None
+
+ def split_data(self, df, valid_boundary=2016, test_boundary=2018):
+ """Splits data frame into training-validation-test data frames.
+
+ This also calibrates scaling object, and transforms data for each split.
+
+ Args:
+ df: Source data frame to split.
+ valid_boundary: Starting year for validation data
+ test_boundary: Starting year for test data
+
+ Returns:
+ Tuple of transformed (train, valid, test) data.
+ """
+
+ print("Formatting train-valid-test splits.")
+
+ index = df["year"]
+ train = df.loc[index < valid_boundary]
+ valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
+ test = df.loc[index >= test_boundary]
+
+ self.set_scalers(train)
+
+ return (self.transform_inputs(data) for data in [train, valid, test])
+
+ def set_scalers(self, df):
+ """Calibrates scalers using the data supplied.
+
+ Args:
+ df: Data to use to calibrate scalers.
+ """
+ print("Setting scalers with training data...")
+
+ column_definitions = self.get_column_definition()
+ id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
+ target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
+
+ # Extract identifiers in case required
+ self.identifiers = list(df[id_column].unique())
+
+ # Format real scalers
+ real_inputs = utils.extract_cols_from_data_type(
+ DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+
+ data = df[real_inputs].values
+ self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
+ self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
+ df[[target_column]].values
+ ) # used for predictions
+
+ # Format categorical scalers
+ categorical_inputs = utils.extract_cols_from_data_type(
+ DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+
+ categorical_scalers = {}
+ num_classes = []
+ for col in categorical_inputs:
+ # Set all to str so that we don't have mixed integer/string columns
+ srs = df[col].apply(str)
+ categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
+ num_classes.append(srs.nunique())
+
+ # Set categorical scaler outputs
+ self._cat_scalers = categorical_scalers
+ self._num_classes_per_cat_input = num_classes
+
+ def transform_inputs(self, df):
+ """Performs feature transformations.
+
+ This includes both feature engineering, preprocessing and normalisation.
+
+ Args:
+ df: Data frame to transform.
+
+ Returns:
+ Transformed data frame.
+
+ """
+ output = df.copy()
+
+ if self._real_scalers is None and self._cat_scalers is None:
+ raise ValueError("Scalers have not been set!")
+
+ column_definitions = self.get_column_definition()
+
+ real_inputs = utils.extract_cols_from_data_type(
+ DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+ categorical_inputs = utils.extract_cols_from_data_type(
+ DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+
+ # Format real inputs
+ output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
+
+ # Format categorical inputs
+ for col in categorical_inputs:
+ string_df = df[col].apply(str)
+ output[col] = self._cat_scalers[col].transform(string_df)
+
+ return output
+
+ def format_predictions(self, predictions):
+ """Reverts any normalisation to give predictions in original scale.
+
+ Args:
+ predictions: Dataframe of model predictions.
+
+ Returns:
+ Data frame of unnormalised predictions.
+ """
+ output = predictions.copy()
+
+ column_names = predictions.columns
+
+ for col in column_names:
+ if col not in {"forecast_time", "identifier"}:
+ output[col] = self._target_scaler.inverse_transform(predictions[col])
+
+ return output
+
+ # Default params
+ def get_fixed_params(self):
+ """Returns fixed model parameters for experiments."""
+
+ fixed_params = {
+ "total_time_steps": 6 + 6,
+ "num_encoder_steps": 6,
+ "num_epochs": 100,
+ "early_stopping_patience": 10,
+ "multiprocessing_workers": 5,
+ }
+
+ return fixed_params
+
+ def get_default_model_params(self):
+ """Returns default optimised model parameters."""
+
+ model_params = {
+ "dropout_rate": 0.4,
+ "hidden_layer_size": 16,
+ "learning_rate": 0.0001,
+ "minibatch_size": 128,
+ "max_gradient_norm": 0.0135,
+ "num_heads": 1,
+ "stack_size": 1,
+ }
+
+ return model_params
diff --git a/examples/benchmarks/TFT/expt_settings/__init__.py b/examples/benchmarks/TFT/expt_settings/__init__.py
new file mode 100644
index 000000000..87ec3284f
--- /dev/null
+++ b/examples/benchmarks/TFT/expt_settings/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/examples/benchmarks/TFT/expt_settings/configs.py b/examples/benchmarks/TFT/expt_settings/configs.py
new file mode 100644
index 000000000..6aef0c395
--- /dev/null
+++ b/examples/benchmarks/TFT/expt_settings/configs.py
@@ -0,0 +1,95 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Default configs for TFT experiments.
+
+Contains the default output paths for data, serialised models and predictions
+for the main experiments used in the publication.
+"""
+
+import os
+
+import data_formatters.qlib_Alpha158
+
+
+class ExperimentConfig(object):
+ """Defines experiment configs and paths to outputs.
+
+ Attributes:
+ root_folder: Root folder to contain all experimental outputs.
+ experiment: Name of experiment to run.
+ data_folder: Folder to store data for experiment.
+ model_folder: Folder to store serialised models.
+ results_folder: Folder to store results.
+ data_csv_path: Path to primary data csv file used in experiment.
+ hyperparam_iterations: Default number of random search iterations for
+ experiment.
+ """
+
+ default_experiments = ["Alpha158"]
+
+ def __init__(self, experiment="volatility", root_folder=None):
+ """Creates configs based on default experiment chosen.
+
+ Args:
+ experiment: Name of experiment.
+ root_folder: Root folder to save all outputs of training.
+ """
+
+ if experiment not in self.default_experiments:
+ raise ValueError("Unrecognised experiment={}".format(experiment))
+
+ # Defines all relevant paths
+ if root_folder is None:
+ root_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "outputs")
+ print("Using root folder {}".format(root_folder))
+
+ self.root_folder = root_folder
+ self.experiment = experiment
+ self.data_folder = os.path.join(root_folder, "data", experiment)
+ self.model_folder = os.path.join(root_folder, "saved_models", experiment)
+ self.results_folder = os.path.join(root_folder, "results", experiment)
+
+ # Creates folders if they don't exist
+ for relevant_directory in [self.root_folder, self.data_folder, self.model_folder, self.results_folder]:
+ if not os.path.exists(relevant_directory):
+ os.makedirs(relevant_directory)
+
+ @property
+ def data_csv_path(self):
+ csv_map = {
+ "Alpha158": "Alpha158.csv",
+ }
+
+ return os.path.join(self.data_folder, csv_map[self.experiment])
+
+ @property
+ def hyperparam_iterations(self):
+
+ return 240 if self.experiment == "volatility" else 60
+
+ def make_data_formatter(self):
+ """Gets a data formatter object for experiment.
+
+ Returns:
+ Default DataFormatter per experiment.
+ """
+
+ data_formatter_class = {
+ "Alpha158": data_formatters.qlib_Alpha158.Alpha158Formatter,
+ }
+
+ return data_formatter_class[self.experiment]()
diff --git a/examples/benchmarks/TFT/libs/__init__.py b/examples/benchmarks/TFT/libs/__init__.py
new file mode 100644
index 000000000..87ec3284f
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/examples/benchmarks/TFT/libs/hyperparam_opt.py b/examples/benchmarks/TFT/libs/hyperparam_opt.py
new file mode 100644
index 000000000..750fdf2c1
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/hyperparam_opt.py
@@ -0,0 +1,430 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Classes used for hyperparameter optimisation.
+
+Two main classes exist:
+1) HyperparamOptManager used for optimisation on a single machine/GPU.
+2) DistributedHyperparamOptManager for multiple GPUs on different machines.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import shutil
+import libs.utils as utils
+import numpy as np
+import pandas as pd
+
+Deque = collections.deque
+
+
+class HyperparamOptManager:
+ """Manages hyperparameter optimisation using random search for a single GPU.
+
+ Attributes:
+ param_ranges: Discrete hyperparameter range for random search.
+ results: Dataframe of validation results.
+ fixed_params: Fixed model parameters per experiment.
+ saved_params: Dataframe of parameters trained.
+ best_score: Minimum validation loss observed thus far.
+ optimal_name: Key to best configuration.
+ hyperparam_folder: Where to save optimisation outputs.
+ """
+
+ def __init__(self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True):
+ """Instantiates model.
+
+ Args:
+ param_ranges: Discrete hyperparameter range for random search.
+ fixed_params: Fixed model parameters per experiment.
+ model_folder: Folder to store optimisation artifacts.
+ override_w_fixed_params: Whether to override serialsed fixed model
+ parameters with new supplied values.
+ """
+
+ self.param_ranges = param_ranges
+
+ self._max_tries = 1000
+ self.results = pd.DataFrame()
+ self.fixed_params = fixed_params
+ self.saved_params = pd.DataFrame()
+
+ self.best_score = np.Inf
+ self.optimal_name = ""
+
+ # Setup
+ # Create folder for saving if its not there
+ self.hyperparam_folder = model_folder
+ utils.create_folder_if_not_exist(self.hyperparam_folder)
+
+ self._override_w_fixed_params = override_w_fixed_params
+
+ def load_results(self):
+ """Loads results from previous hyperparameter optimisation.
+
+ Returns:
+ A boolean indicating if previous results can be loaded.
+ """
+ print("Loading results from", self.hyperparam_folder)
+
+ results_file = os.path.join(self.hyperparam_folder, "results.csv")
+ params_file = os.path.join(self.hyperparam_folder, "params.csv")
+
+ if os.path.exists(results_file) and os.path.exists(params_file):
+
+ self.results = pd.read_csv(results_file, index_col=0)
+ self.saved_params = pd.read_csv(params_file, index_col=0)
+
+ if not self.results.empty:
+ self.results.at["loss"] = self.results.loc["loss"].apply(float)
+ self.best_score = self.results.loc["loss"].min()
+
+ is_optimal = self.results.loc["loss"] == self.best_score
+ self.optimal_name = self.results.T[is_optimal].index[0]
+
+ return True
+
+ return False
+
+ def _get_params_from_name(self, name):
+ """Returns previously saved parameters given a key."""
+ params = self.saved_params
+
+ selected_params = dict(params[name])
+
+ if self._override_w_fixed_params:
+ for k in self.fixed_params:
+ selected_params[k] = self.fixed_params[k]
+
+ return selected_params
+
+ def get_best_params(self):
+ """Returns the optimal hyperparameters thus far."""
+
+ optimal_name = self.optimal_name
+
+ return self._get_params_from_name(optimal_name)
+
+ def clear(self):
+ """Clears all previous results and saved parameters."""
+ shutil.rmtree(self.hyperparam_folder)
+ os.makedirs(self.hyperparam_folder)
+ self.results = pd.DataFrame()
+ self.saved_params = pd.DataFrame()
+
+ def _check_params(self, params):
+ """Checks that parameter map is properly defined."""
+
+ valid_fields = list(self.param_ranges.keys()) + list(self.fixed_params.keys())
+ invalid_fields = [k for k in params if k not in valid_fields]
+ missing_fields = [k for k in valid_fields if k not in params]
+
+ if invalid_fields:
+ raise ValueError("Invalid Fields Found {} - Valid ones are {}".format(invalid_fields, valid_fields))
+ if missing_fields:
+ raise ValueError("Missing Fields Found {} - Valid ones are {}".format(missing_fields, valid_fields))
+
+ def _get_name(self, params):
+ """Returns a unique key for the supplied set of params."""
+
+ self._check_params(params)
+
+ fields = list(params.keys())
+ fields.sort()
+
+ return "_".join([str(params[k]) for k in fields])
+
+ def get_next_parameters(self, ranges_to_skip=None):
+ """Returns the next set of parameters to optimise.
+
+ Args:
+ ranges_to_skip: Explicitly defines a set of keys to skip.
+ """
+ if ranges_to_skip is None:
+ ranges_to_skip = set(self.results.index)
+
+ if not isinstance(self.param_ranges, dict):
+ raise ValueError("Only works for random search!")
+
+ param_range_keys = list(self.param_ranges.keys())
+ param_range_keys.sort()
+
+ def _get_next():
+ """Returns next hyperparameter set per try."""
+
+ parameters = {k: np.random.choice(self.param_ranges[k]) for k in param_range_keys}
+
+ # Adds fixed params
+ for k in self.fixed_params:
+ parameters[k] = self.fixed_params[k]
+
+ return parameters
+
+ for _ in range(self._max_tries):
+
+ parameters = _get_next()
+ name = self._get_name(parameters)
+
+ if name not in ranges_to_skip:
+ return parameters
+
+ raise ValueError("Exceeded max number of hyperparameter searches!!")
+
+ def update_score(self, parameters, loss, model, info=""):
+ """Updates the results from last optimisation run.
+
+ Args:
+ parameters: Hyperparameters used in optimisation.
+ loss: Validation loss obtained.
+ model: Model to serialised if required.
+ info: Any ancillary information to tag on to results.
+
+ Returns:
+ Boolean flag indicating if the model is the best seen so far.
+ """
+
+ if np.isnan(loss):
+ loss = np.Inf
+
+ if not os.path.isdir(self.hyperparam_folder):
+ os.makedirs(self.hyperparam_folder)
+
+ name = self._get_name(parameters)
+
+ is_optimal = self.results.empty or loss < self.best_score
+
+ # save the first model
+ if is_optimal:
+ # Try saving first, before updating info
+ if model is not None:
+ print("Optimal model found, updating")
+ model.save(self.hyperparam_folder)
+ self.best_score = loss
+ self.optimal_name = name
+
+ self.results[name] = pd.Series({"loss": loss, "info": info})
+ self.saved_params[name] = pd.Series(parameters)
+
+ self.results.to_csv(os.path.join(self.hyperparam_folder, "results.csv"))
+ self.saved_params.to_csv(os.path.join(self.hyperparam_folder, "params.csv"))
+
+ return is_optimal
+
+
+class DistributedHyperparamOptManager(HyperparamOptManager):
+ """Manages distributed hyperparameter optimisation across many gpus."""
+
+ def __init__(
+ self,
+ param_ranges,
+ fixed_params,
+ root_model_folder,
+ worker_number,
+ search_iterations=1000,
+ num_iterations_per_worker=5,
+ clear_serialised_params=False,
+ ):
+ """Instantiates optimisation manager.
+
+ This hyperparameter optimisation pre-generates #search_iterations
+ hyperparameter combinations and serialises them
+ at the start. At runtime, each worker goes through their own set of
+ parameter ranges. The pregeneration
+ allows for multiple workers to run in parallel on different machines without
+ resulting in parameter overlaps.
+
+ Args:
+ param_ranges: Discrete hyperparameter range for random search.
+ fixed_params: Fixed model parameters per experiment.
+ root_model_folder: Folder to store optimisation artifacts.
+ worker_number: Worker index definining which set of hyperparameters to
+ test.
+ search_iterations: Maximum numer of random search iterations.
+ num_iterations_per_worker: How many iterations are handled per worker.
+ clear_serialised_params: Whether to regenerate hyperparameter
+ combinations.
+ """
+
+ max_workers = int(np.ceil(search_iterations / num_iterations_per_worker))
+
+ # Sanity checks
+ if worker_number > max_workers:
+ raise ValueError(
+ "Worker number ({}) cannot be larger than the total number of workers!".format(max_workers)
+ )
+ if worker_number > search_iterations:
+ raise ValueError(
+ "Worker number ({}) cannot be larger than the max search iterations ({})!".format(
+ worker_number, search_iterations
+ )
+ )
+
+ print("*** Creating hyperparameter manager for worker {} ***".format(worker_number))
+
+ hyperparam_folder = os.path.join(root_model_folder, str(worker_number))
+ super().__init__(param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True)
+
+ serialised_ranges_folder = os.path.join(root_model_folder, "hyperparams")
+ if clear_serialised_params:
+ print("Regenerating hyperparameter list")
+ if os.path.exists(serialised_ranges_folder):
+ shutil.rmtree(serialised_ranges_folder)
+
+ utils.create_folder_if_not_exist(serialised_ranges_folder)
+
+ self.serialised_ranges_path = os.path.join(serialised_ranges_folder, "ranges_{}.csv".format(search_iterations))
+ self.hyperparam_folder = hyperparam_folder # override
+ self.worker_num = worker_number
+ self.total_search_iterations = search_iterations
+ self.num_iterations_per_worker = num_iterations_per_worker
+ self.global_hyperparam_df = self.load_serialised_hyperparam_df()
+ self.worker_search_queue = self._get_worker_search_queue()
+
+ @property
+ def optimisation_completed(self):
+ return False if self.worker_search_queue else True
+
+ def get_next_parameters(self):
+ """Returns next dictionary of hyperparameters to optimise."""
+ param_name = self.worker_search_queue.pop()
+
+ params = self.global_hyperparam_df.loc[param_name, :].to_dict()
+
+ # Always override!
+ for k in self.fixed_params:
+ print("Overriding saved {}: {}".format(k, self.fixed_params[k]))
+
+ params[k] = self.fixed_params[k]
+
+ return params
+
+ def load_serialised_hyperparam_df(self):
+ """Loads serialsed hyperparameter ranges from file.
+
+ Returns:
+ DataFrame containing hyperparameter combinations.
+ """
+ print(
+ "Loading params for {} search iterations form {}".format(
+ self.total_search_iterations, self.serialised_ranges_path
+ )
+ )
+
+ if os.path.exists(self.serialised_ranges_folder):
+ df = pd.read_csv(self.serialised_ranges_path, index_col=0)
+ else:
+ print("Unable to load - regenerating serach ranges instead")
+ df = self.update_serialised_hyperparam_df()
+
+ return df
+
+ def update_serialised_hyperparam_df(self):
+ """Regenerates hyperparameter combinations and saves to file.
+
+ Returns:
+ DataFrame containing hyperparameter combinations.
+ """
+ search_df = self._generate_full_hyperparam_df()
+
+ print(
+ "Serialising params for {} search iterations to {}".format(
+ self.total_search_iterations, self.serialised_ranges_path
+ )
+ )
+
+ search_df.to_csv(self.serialised_ranges_path)
+
+ return search_df
+
+ def _generate_full_hyperparam_df(self):
+ """Generates actual hyperparameter combinations.
+
+ Returns:
+ DataFrame containing hyperparameter combinations.
+ """
+
+ np.random.seed(131) # for reproducibility of hyperparam list
+
+ name_list = []
+ param_list = []
+ for _ in range(self.total_search_iterations):
+ params = super().get_next_parameters(name_list)
+
+ name = self._get_name(params)
+
+ name_list.append(name)
+ param_list.append(params)
+
+ full_search_df = pd.DataFrame(param_list, index=name_list)
+
+ return full_search_df
+
+ def clear(self): # reset when cleared
+ """Clears results for hyperparameter manager and resets."""
+ super().clear()
+ self.worker_search_queue = self._get_worker_search_queue()
+
+ def load_results(self):
+ """Load results from file and queue parameter combinations to try.
+
+ Returns:
+ Boolean indicating if results were successfully loaded.
+ """
+ success = super().load_results()
+
+ if success:
+ self.worker_search_queue = self._get_worker_search_queue()
+
+ return success
+
+ def _get_worker_search_queue(self):
+ """Generates the queue of param combinations for current worker.
+
+ Returns:
+ Queue of hyperparameter combinations outstanding.
+ """
+ global_df = self.assign_worker_numbers(self.global_hyperparam_df)
+ worker_df = global_df[global_df["worker"] == self.worker_num]
+
+ left_overs = [s for s in worker_df.index if s not in self.results.columns]
+
+ return Deque(left_overs)
+
+ def assign_worker_numbers(self, df):
+ """Updates parameter combinations with the index of the worker used.
+
+ Args:
+ df: DataFrame of parameter combinations.
+
+ Returns:
+ Updated DataFrame with worker number.
+ """
+ output = df.copy()
+
+ n = self.total_search_iterations
+ batch_size = self.num_iterations_per_worker
+
+ max_worker_num = int(np.ceil(n / batch_size))
+
+ worker_idx = np.concatenate([np.tile(i + 1, self.num_iterations_per_worker) for i in range(max_worker_num)])
+
+ output["worker"] = worker_idx[: len(output)]
+
+ return output
diff --git a/examples/benchmarks/TFT/libs/tft_model.py b/examples/benchmarks/TFT/libs/tft_model.py
new file mode 100644
index 000000000..658bae60f
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/tft_model.py
@@ -0,0 +1,1280 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Temporal Fusion Transformer Model.
+
+Contains the full TFT architecture and associated components. Defines functions
+for training, evaluation and prediction using simple Pandas Dataframe inputs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import json
+import os
+import shutil
+
+import data_formatters.base
+import libs.utils as utils
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+
+# Layer definitions.
+concat = tf.keras.backend.concatenate
+stack = tf.keras.backend.stack
+K = tf.keras.backend
+Add = tf.keras.layers.Add
+LayerNorm = tf.keras.layers.LayerNormalization
+Dense = tf.keras.layers.Dense
+Multiply = tf.keras.layers.Multiply
+Dropout = tf.keras.layers.Dropout
+Activation = tf.keras.layers.Activation
+Lambda = tf.keras.layers.Lambda
+
+# Default input types.
+InputTypes = data_formatters.base.InputTypes
+
+
+# Layer utility functions.
+def linear_layer(size, activation=None, use_time_distributed=False, use_bias=True):
+ """Returns simple Keras linear layer.
+
+ Args:
+ size: Output size
+ activation: Activation function to apply if required
+ use_time_distributed: Whether to apply layer across time
+ use_bias: Whether bias should be included in layer
+ """
+ linear = tf.keras.layers.Dense(size, activation=activation, use_bias=use_bias)
+ if use_time_distributed:
+ linear = tf.keras.layers.TimeDistributed(linear)
+ return linear
+
+
+def apply_mlp(
+ inputs, hidden_size, output_size, output_activation=None, hidden_activation="tanh", use_time_distributed=False
+):
+ """Applies simple feed-forward network to an input.
+
+ Args:
+ inputs: MLP inputs
+ hidden_size: Hidden state size
+ output_size: Output size of MLP
+ output_activation: Activation function to apply on output
+ hidden_activation: Activation function to apply on input
+ use_time_distributed: Whether to apply across time
+
+ Returns:
+ Tensor for MLP outputs.
+ """
+ if use_time_distributed:
+ hidden = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_size, activation=hidden_activation))(
+ inputs
+ )
+ return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(output_size, activation=output_activation))(hidden)
+ else:
+ hidden = tf.keras.layers.Dense(hidden_size, activation=hidden_activation)(inputs)
+ return tf.keras.layers.Dense(output_size, activation=output_activation)(hidden)
+
+
+def apply_gating_layer(x, hidden_layer_size, dropout_rate=None, use_time_distributed=True, activation=None):
+ """Applies a Gated Linear Unit (GLU) to an input.
+
+ Args:
+ x: Input to gating layer
+ hidden_layer_size: Dimension of GLU
+ dropout_rate: Dropout rate to apply if any
+ use_time_distributed: Whether to apply across time
+ activation: Activation function to apply to the linear feature transform if
+ necessary
+
+ Returns:
+ Tuple of tensors for: (GLU output, gate)
+ """
+
+ if dropout_rate is not None:
+ x = tf.keras.layers.Dropout(dropout_rate)(x)
+
+ if use_time_distributed:
+ activation_layer = tf.keras.layers.TimeDistributed(
+ tf.keras.layers.Dense(hidden_layer_size, activation=activation)
+ )(x)
+ gated_layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid"))(x)
+ else:
+ activation_layer = tf.keras.layers.Dense(hidden_layer_size, activation=activation)(x)
+ gated_layer = tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid")(x)
+
+ return tf.keras.layers.Multiply()([activation_layer, gated_layer]), gated_layer
+
+
+def add_and_norm(x_list):
+ """Applies skip connection followed by layer normalisation.
+
+ Args:
+ x_list: List of inputs to sum for skip connection
+
+ Returns:
+ Tensor output from layer.
+ """
+ tmp = Add()(x_list)
+ tmp = LayerNorm()(tmp)
+ return tmp
+
+
+def gated_residual_network(
+ x,
+ hidden_layer_size,
+ output_size=None,
+ dropout_rate=None,
+ use_time_distributed=True,
+ additional_context=None,
+ return_gate=False,
+):
+ """Applies the gated residual network (GRN) as defined in paper.
+
+ Args:
+ x: Network inputs
+ hidden_layer_size: Internal state size
+ output_size: Size of output layer
+ dropout_rate: Dropout rate if dropout is applied
+ use_time_distributed: Whether to apply network across time dimension
+ additional_context: Additional context vector to use if relevant
+ return_gate: Whether to return GLU gate for diagnostic purposes
+
+ Returns:
+ Tuple of tensors for: (GRN output, GLU gate)
+ """
+
+ # Setup skip connection
+ if output_size is None:
+ output_size = hidden_layer_size
+ skip = x
+ else:
+ linear = Dense(output_size)
+ if use_time_distributed:
+ linear = tf.keras.layers.TimeDistributed(linear)
+ skip = linear(x)
+
+ # Apply feedforward network
+ hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(x)
+ if additional_context is not None:
+ hidden = hidden + linear_layer(
+ hidden_layer_size, activation=None, use_time_distributed=use_time_distributed, use_bias=False
+ )(additional_context)
+ hidden = tf.keras.layers.Activation("elu")(hidden)
+ hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(hidden)
+
+ gating_layer, gate = apply_gating_layer(
+ hidden, output_size, dropout_rate=dropout_rate, use_time_distributed=use_time_distributed, activation=None
+ )
+
+ if return_gate:
+ return add_and_norm([skip, gating_layer]), gate
+ else:
+ return add_and_norm([skip, gating_layer])
+
+
+# Attention Components.
+def get_decoder_mask(self_attn_inputs):
+ """Returns causal mask to apply for self-attention layer.
+
+ Args:
+ self_attn_inputs: Inputs to self attention layer to determine mask shape
+ """
+ len_s = tf.shape(self_attn_inputs)[1]
+ bs = tf.shape(self_attn_inputs)[:1]
+ mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)
+ return mask
+
+
+class ScaledDotProductAttention:
+ """Defines scaled dot product attention layer.
+
+ Attributes:
+ dropout: Dropout rate to use
+ activation: Normalisation function for scaled dot product attention (e.g.
+ softmax by default)
+ """
+
+ def __init__(self, attn_dropout=0.0):
+ self.dropout = Dropout(attn_dropout)
+ self.activation = Activation("softmax")
+
+ def __call__(self, q, k, v, mask):
+ """Applies scaled dot product attention.
+
+ Args:
+ q: Queries
+ k: Keys
+ v: Values
+ mask: Masking if required -- sets softmax to very large value
+
+ Returns:
+ Tuple of (layer outputs, attention weights)
+ """
+ temper = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype="float32"))
+ attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / temper)([q, k]) # shape=(batch, q, k)
+ if mask is not None:
+ mmask = Lambda(lambda x: (-1e9) * (1.0 - K.cast(x, "float32")))(mask) # setting to infinity
+ attn = Add()([attn, mmask])
+ attn = self.activation(attn)
+ attn = self.dropout(attn)
+ output = Lambda(lambda x: K.batch_dot(x[0], x[1]))([attn, v])
+ return output, attn
+
+
+class InterpretableMultiHeadAttention:
+ """Defines interpretable multi-head attention layer.
+
+ Attributes:
+ n_head: Number of heads
+ d_k: Key/query dimensionality per head
+ d_v: Value dimensionality
+ dropout: Dropout rate to apply
+ qs_layers: List of queries across heads
+ ks_layers: List of keys across heads
+ vs_layers: List of values across heads
+ attention: Scaled dot product attention layer
+ w_o: Output weight matrix to project internal state to the original TFT
+ state size
+ """
+
+ def __init__(self, n_head, d_model, dropout):
+ """Initialises layer.
+
+ Args:
+ n_head: Number of heads
+ d_model: TFT state dimensionality
+ dropout: Dropout discard rate
+ """
+
+ self.n_head = n_head
+ self.d_k = self.d_v = d_k = d_v = d_model // n_head
+ self.dropout = dropout
+
+ self.qs_layers = []
+ self.ks_layers = []
+ self.vs_layers = []
+
+ # Use same value layer to facilitate interp
+ vs_layer = Dense(d_v, use_bias=False)
+
+ for _ in range(n_head):
+ self.qs_layers.append(Dense(d_k, use_bias=False))
+ self.ks_layers.append(Dense(d_k, use_bias=False))
+ self.vs_layers.append(vs_layer) # use same vs_layer
+
+ self.attention = ScaledDotProductAttention()
+ self.w_o = Dense(d_model, use_bias=False)
+
+ def __call__(self, q, k, v, mask=None):
+ """Applies interpretable multihead attention.
+
+ Using T to denote the number of time steps fed into the transformer.
+
+ Args:
+ q: Query tensor of shape=(?, T, d_model)
+ k: Key of shape=(?, T, d_model)
+ v: Values of shape=(?, T, d_model)
+ mask: Masking if required with shape=(?, T, T)
+
+ Returns:
+ Tuple of (layer outputs, attention weights)
+ """
+ n_head = self.n_head
+
+ heads = []
+ attns = []
+ for i in range(n_head):
+ qs = self.qs_layers[i](q)
+ ks = self.ks_layers[i](k)
+ vs = self.vs_layers[i](v)
+ head, attn = self.attention(qs, ks, vs, mask)
+
+ head_dropout = Dropout(self.dropout)(head)
+ heads.append(head_dropout)
+ attns.append(attn)
+ head = K.stack(heads) if n_head > 1 else heads[0]
+ attn = K.stack(attns)
+
+ outputs = K.mean(head, axis=0) if n_head > 1 else head
+ outputs = self.w_o(outputs)
+ outputs = Dropout(self.dropout)(outputs) # output dropout
+
+ return outputs, attn
+
+
+class TFTDataCache(object):
+ """Caches data for the TFT."""
+
+ _data_cache = {}
+
+ @classmethod
+ def update(cls, data, key):
+ """Updates cached data.
+
+ Args:
+ data: Source to update
+ key: Key to dictionary location
+ """
+ cls._data_cache[key] = data
+
+ @classmethod
+ def get(cls, key):
+ """Returns data stored at key location."""
+ return cls._data_cache[key].copy()
+
+ @classmethod
+ def contains(cls, key):
+ """Retuns boolean indicating whether key is present in cache."""
+
+ return key in cls._data_cache
+
+
+# TFT model definitions.
+class TemporalFusionTransformer(object):
+ """Defines Temporal Fusion Transformer.
+
+ Attributes:
+ name: Name of model
+ time_steps: Total number of input time steps per forecast date (i.e. Width
+ of Temporal fusion decoder N)
+ input_size: Total number of inputs
+ output_size: Total number of outputs
+ category_counts: Number of categories per categorical variable
+ n_multiprocessing_workers: Number of workers to use for parallel
+ computations
+ column_definition: List of tuples of (string, DataType, InputType) that
+ define each column
+ quantiles: Quantiles to forecast for TFT
+ use_cudnn: Whether to use Keras CuDNNLSTM or standard LSTM layers
+ hidden_layer_size: Internal state size of TFT
+ dropout_rate: Dropout discard rate
+ max_gradient_norm: Maximum norm for gradient clipping
+ learning_rate: Initial learning rate of ADAM optimizer
+ minibatch_size: Size of minibatches for training
+ num_epochs: Maximum number of epochs for training
+ early_stopping_patience: Maximum number of iterations of non-improvement
+ before early stopping kicks in
+ num_encoder_steps: Size of LSTM encoder -- i.e. number of past time steps
+ before forecast date to use
+ num_stacks: Number of self-attention layers to apply (default is 1 for basic
+ TFT)
+ num_heads: Number of heads for interpretable mulit-head attention
+ model: Keras model for TFT
+ """
+
+ def __init__(self, raw_params, use_cudnn=False):
+ """Builds TFT from parameters.
+
+ Args:
+ raw_params: Parameters to define TFT
+ use_cudnn: Whether to use CUDNN GPU optimised LSTM
+ """
+
+ self.name = self.__class__.__name__
+
+ params = dict(raw_params) # copy locally
+
+ # Data parameters
+ self.time_steps = int(params["total_time_steps"])
+ self.input_size = int(params["input_size"])
+ self.output_size = int(params["output_size"])
+ self.category_counts = json.loads(str(params["category_counts"]))
+ self.n_multiprocessing_workers = int(params["multiprocessing_workers"])
+
+ # Relevant indices for TFT
+ self._input_obs_loc = json.loads(str(params["input_obs_loc"]))
+ self._static_input_loc = json.loads(str(params["static_input_loc"]))
+ self._known_regular_input_idx = json.loads(str(params["known_regular_inputs"]))
+ self._known_categorical_input_idx = json.loads(str(params["known_categorical_inputs"]))
+
+ self.column_definition = params["column_definition"]
+
+ # Network params
+ self.quantiles = [0.1, 0.5, 0.9]
+ self.use_cudnn = use_cudnn # Whether to use GPU optimised LSTM
+ self.hidden_layer_size = int(params["hidden_layer_size"])
+ self.dropout_rate = float(params["dropout_rate"])
+ self.max_gradient_norm = float(params["max_gradient_norm"])
+ self.learning_rate = float(params["learning_rate"])
+ self.minibatch_size = int(params["minibatch_size"])
+ self.num_epochs = int(params["num_epochs"])
+ self.early_stopping_patience = int(params["early_stopping_patience"])
+
+ self.num_encoder_steps = int(params["num_encoder_steps"])
+ self.num_stacks = int(params["stack_size"])
+ self.num_heads = int(params["num_heads"])
+
+ # Serialisation options
+ self._temp_folder = os.path.join(params["model_folder"], "tmp")
+ self.reset_temp_folder()
+
+ # Extra components to store Tensorflow nodes for attention computations
+ self._input_placeholder = None
+ self._attention_components = None
+ self._prediction_parts = None
+
+ print("*** {} params ***".format(self.name))
+ for k in params:
+ print("# {} = {}".format(k, params[k]))
+
+ # Build model
+ self.model = self.build_model()
+
+ def get_tft_embeddings(self, all_inputs):
+ """Transforms raw inputs to embeddings.
+
+ Applies linear transformation onto continuous variables and uses embeddings
+ for categorical variables.
+
+ Args:
+ all_inputs: Inputs to transform
+
+ Returns:
+ Tensors for transformed inputs.
+ """
+
+ time_steps = self.time_steps
+
+ # Sanity checks
+ for i in self._known_regular_input_idx:
+ if i in self._input_obs_loc:
+ raise ValueError("Observation cannot be known a priori!")
+ for i in self._input_obs_loc:
+ if i in self._static_input_loc:
+ raise ValueError("Observation cannot be static!")
+
+ if all_inputs.get_shape().as_list()[-1] != self.input_size:
+ raise ValueError(
+ "Illegal number of inputs! Inputs observed={}, expected={}".format(
+ all_inputs.get_shape().as_list()[-1], self.input_size
+ )
+ )
+
+ num_categorical_variables = len(self.category_counts)
+ num_regular_variables = self.input_size - num_categorical_variables
+
+ embedding_sizes = [self.hidden_layer_size for i, size in enumerate(self.category_counts)]
+
+ embeddings = []
+ for i in range(num_categorical_variables):
+
+ embedding = tf.keras.Sequential(
+ [
+ tf.keras.layers.InputLayer([time_steps]),
+ tf.keras.layers.Embedding(
+ self.category_counts[i], embedding_sizes[i], input_length=time_steps, dtype=tf.float32
+ ),
+ ]
+ )
+ embeddings.append(embedding)
+
+ regular_inputs, categorical_inputs = (
+ all_inputs[:, :, :num_regular_variables],
+ all_inputs[:, :, num_regular_variables:],
+ )
+
+ embedded_inputs = [embeddings[i](categorical_inputs[Ellipsis, i]) for i in range(num_categorical_variables)]
+
+ # Static inputs
+ if self._static_input_loc:
+ static_inputs = [
+ tf.keras.layers.Dense(self.hidden_layer_size)(regular_inputs[:, 0, i : i + 1])
+ for i in range(num_regular_variables)
+ if i in self._static_input_loc
+ ] + [
+ embedded_inputs[i][:, 0, :]
+ for i in range(num_categorical_variables)
+ if i + num_regular_variables in self._static_input_loc
+ ]
+ static_inputs = tf.keras.backend.stack(static_inputs, axis=1)
+
+ else:
+ static_inputs = None
+
+ def convert_real_to_embedding(x):
+ """Applies linear transformation for time-varying inputs."""
+ return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.hidden_layer_size))(x)
+
+ # Targets
+ obs_inputs = tf.keras.backend.stack(
+ [convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1]) for i in self._input_obs_loc], axis=-1
+ )
+
+ # Observed (a prioir unknown) inputs
+ wired_embeddings = []
+ for i in range(num_categorical_variables):
+ if i not in self._known_categorical_input_idx and i + num_regular_variables not in self._input_obs_loc:
+ e = embeddings[i](categorical_inputs[:, :, i])
+ wired_embeddings.append(e)
+
+ unknown_inputs = []
+ for i in range(regular_inputs.shape[-1]):
+ if i not in self._known_regular_input_idx and i not in self._input_obs_loc:
+ e = convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])
+ unknown_inputs.append(e)
+
+ if unknown_inputs + wired_embeddings:
+ unknown_inputs = tf.keras.backend.stack(unknown_inputs + wired_embeddings, axis=-1)
+ else:
+ unknown_inputs = None
+
+ # A priori known inputs
+ known_regular_inputs = [
+ convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])
+ for i in self._known_regular_input_idx
+ if i not in self._static_input_loc
+ ]
+ known_categorical_inputs = [
+ embedded_inputs[i]
+ for i in self._known_categorical_input_idx
+ if i + num_regular_variables not in self._static_input_loc
+ ]
+
+ known_combined_layer = tf.keras.backend.stack(known_regular_inputs + known_categorical_inputs, axis=-1)
+
+ return unknown_inputs, known_combined_layer, obs_inputs, static_inputs
+
+ def _get_single_col_by_type(self, input_type):
+ """Returns name of single column for input type."""
+
+ return utils.get_single_col_by_input_type(input_type, self.column_definition)
+
+ def training_data_cached(self):
+ """Returns boolean indicating if training data has been cached."""
+
+ return TFTDataCache.contains("train") and TFTDataCache.contains("valid")
+
+ def cache_batched_data(self, data, cache_key, num_samples=-1):
+ """Batches and caches data once for using during training.
+
+ Args:
+ data: Data to batch and cache
+ cache_key: Key used for cache
+ num_samples: Maximum number of samples to extract (-1 to use all data)
+ """
+
+ if num_samples > 0:
+ TFTDataCache.update(self._batch_sampled_data(data, max_samples=num_samples), cache_key)
+ else:
+ TFTDataCache.update(self._batch_data(data), cache_key)
+
+ print('Cached data "{}" updated'.format(cache_key))
+
+ def _batch_sampled_data(self, data, max_samples):
+ """Samples segments into a compatible format.
+
+ Args:
+ data: Sources data to sample and batch
+ max_samples: Maximum number of samples in batch
+
+ Returns:
+ Dictionary of batched data with the maximum samples specified.
+ """
+
+ if max_samples < 1:
+ raise ValueError("Illegal number of samples specified! samples={}".format(max_samples))
+
+ id_col = self._get_single_col_by_type(InputTypes.ID)
+ time_col = self._get_single_col_by_type(InputTypes.TIME)
+
+ data.sort_values(by=[id_col, time_col], inplace=True)
+
+ print("Getting valid sampling locations.")
+ valid_sampling_locations = []
+ split_data_map = {}
+ for identifier, df in data.groupby(id_col):
+ print("Getting locations for {}".format(identifier))
+ num_entries = len(df)
+ if num_entries >= self.time_steps:
+ valid_sampling_locations += [
+ (identifier, self.time_steps + i) for i in range(num_entries - self.time_steps + 1)
+ ]
+ split_data_map[identifier] = df
+
+ inputs = np.zeros((max_samples, self.time_steps, self.input_size))
+ outputs = np.zeros((max_samples, self.time_steps, self.output_size))
+ time = np.empty((max_samples, self.time_steps, 1), dtype=object)
+ identifiers = np.empty((max_samples, self.time_steps, 1), dtype=object)
+
+ if max_samples > 0 and len(valid_sampling_locations) > max_samples:
+ print("Extracting {} samples...".format(max_samples))
+ ranges = [
+ valid_sampling_locations[i]
+ for i in np.random.choice(len(valid_sampling_locations), max_samples, replace=False)
+ ]
+ else:
+ print("Max samples={} exceeds # available segments={}".format(max_samples, len(valid_sampling_locations)))
+ ranges = valid_sampling_locations
+
+ id_col = self._get_single_col_by_type(InputTypes.ID)
+ time_col = self._get_single_col_by_type(InputTypes.TIME)
+ target_col = self._get_single_col_by_type(InputTypes.TARGET)
+ input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ for i, tup in enumerate(ranges):
+ if (i + 1 % 1000) == 0:
+ print(i + 1, "of", max_samples, "samples done...")
+ identifier, start_idx = tup
+ sliced = split_data_map[identifier].iloc[start_idx - self.time_steps : start_idx]
+ inputs[i, :, :] = sliced[input_cols]
+ outputs[i, :, :] = sliced[[target_col]]
+ time[i, :, 0] = sliced[time_col]
+ identifiers[i, :, 0] = sliced[id_col]
+
+ sampled_data = {
+ "inputs": inputs,
+ "outputs": outputs[:, self.num_encoder_steps :, :],
+ "active_entries": np.ones_like(outputs[:, self.num_encoder_steps :, :]),
+ "time": time,
+ "identifier": identifiers,
+ }
+
+ return sampled_data
+
+ def _batch_data(self, data):
+ """Batches data for training.
+
+ Converts raw dataframe from a 2-D tabular format to a batched 3-D array
+ to feed into Keras model.
+
+ Args:
+ data: DataFrame to batch
+
+ Returns:
+ Batched Numpy array with shape=(?, self.time_steps, self.input_size)
+ """
+
+ # Functions.
+ def _batch_single_entity(input_data):
+ time_steps = len(input_data)
+ lags = self.time_steps
+ x = input_data.values
+ if time_steps >= lags:
+ return np.stack([x[i : time_steps - (lags - 1) + i, :] for i in range(lags)], axis=1)
+
+ else:
+ return None
+
+ id_col = self._get_single_col_by_type(InputTypes.ID)
+ time_col = self._get_single_col_by_type(InputTypes.TIME)
+ target_col = self._get_single_col_by_type(InputTypes.TARGET)
+ input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ data_map = {}
+ for _, sliced in data.groupby(id_col):
+
+ col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols}
+
+ for k in col_mappings:
+ cols = col_mappings[k]
+ arr = _batch_single_entity(sliced[cols].copy())
+
+ if k not in data_map:
+ data_map[k] = [arr]
+ else:
+ data_map[k].append(arr)
+
+ # Combine all data
+ for k in data_map:
+ # Wendi: Avoid returning None when the length is not enough
+ data_map[k] = np.concatenate([i for i in data_map[k] if i is not None], axis=0)
+
+ # Shorten target so we only get decoder steps
+ data_map["outputs"] = data_map["outputs"][:, self.num_encoder_steps :, :]
+
+ active_entries = np.ones_like(data_map["outputs"])
+ if "active_entries" not in data_map:
+ data_map["active_entries"] = active_entries
+ else:
+ data_map["active_entries"].append(active_entries)
+
+ return data_map
+
+ def _get_active_locations(self, x):
+ """Formats sample weights for Keras training."""
+ return (np.sum(x, axis=-1) > 0.0) * 1.0
+
+ def _build_base_graph(self):
+ """Returns graph defining layers of the TFT."""
+
+ # Size definitions.
+ time_steps = self.time_steps
+ combined_input_size = self.input_size
+ encoder_steps = self.num_encoder_steps
+
+ # Inputs.
+ all_inputs = tf.keras.layers.Input(
+ shape=(
+ time_steps,
+ combined_input_size,
+ )
+ )
+
+ unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)
+
+ # Isolate known and observed historical inputs.
+ if unknown_inputs is not None:
+ historical_inputs = concat(
+ [
+ unknown_inputs[:, :encoder_steps, :],
+ known_combined_layer[:, :encoder_steps, :],
+ obs_inputs[:, :encoder_steps, :],
+ ],
+ axis=-1,
+ )
+ else:
+ historical_inputs = concat(
+ [known_combined_layer[:, :encoder_steps, :], obs_inputs[:, :encoder_steps, :]], axis=-1
+ )
+
+ # Isolate only known future inputs.
+ future_inputs = known_combined_layer[:, encoder_steps:, :]
+
+ def static_combine_and_mask(embedding):
+ """Applies variable selection network to static inputs.
+
+ Args:
+ embedding: Transformed static inputs
+
+ Returns:
+ Tensor output for variable selection network
+ """
+
+ # Add temporal features
+ _, num_static, _ = embedding.get_shape().as_list()
+
+ flatten = tf.keras.layers.Flatten()(embedding)
+
+ # Nonlinear transformation with gated residual network.
+ mlp_outputs = gated_residual_network(
+ flatten,
+ self.hidden_layer_size,
+ output_size=num_static,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=False,
+ additional_context=None,
+ )
+
+ sparse_weights = tf.keras.layers.Activation("softmax")(mlp_outputs)
+ sparse_weights = K.expand_dims(sparse_weights, axis=-1)
+
+ trans_emb_list = []
+ for i in range(num_static):
+ e = gated_residual_network(
+ embedding[:, i : i + 1, :],
+ self.hidden_layer_size,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=False,
+ )
+ trans_emb_list.append(e)
+
+ transformed_embedding = concat(trans_emb_list, axis=1)
+
+ combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])
+
+ static_vec = K.sum(combined, axis=1)
+
+ return static_vec, sparse_weights
+
+ static_encoder, static_weights = static_combine_and_mask(static_inputs)
+
+ static_context_variable_selection = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+ static_context_enrichment = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+ static_context_state_h = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+ static_context_state_c = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+
+ def lstm_combine_and_mask(embedding):
+ """Apply temporal variable selection networks.
+
+ Args:
+ embedding: Transformed inputs.
+
+ Returns:
+ Processed tensor outputs.
+ """
+
+ # Add temporal features
+ _, time_steps, embedding_dim, num_inputs = embedding.get_shape().as_list()
+
+ flatten = K.reshape(embedding, [-1, time_steps, embedding_dim * num_inputs])
+
+ expanded_static_context = K.expand_dims(static_context_variable_selection, axis=1)
+
+ # Variable selection weights
+ mlp_outputs, static_gate = gated_residual_network(
+ flatten,
+ self.hidden_layer_size,
+ output_size=num_inputs,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=True,
+ additional_context=expanded_static_context,
+ return_gate=True,
+ )
+
+ sparse_weights = tf.keras.layers.Activation("softmax")(mlp_outputs)
+ sparse_weights = tf.expand_dims(sparse_weights, axis=2)
+
+ # Non-linear Processing & weight application
+ trans_emb_list = []
+ for i in range(num_inputs):
+ grn_output = gated_residual_network(
+ embedding[Ellipsis, i],
+ self.hidden_layer_size,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=True,
+ )
+ trans_emb_list.append(grn_output)
+
+ transformed_embedding = stack(trans_emb_list, axis=-1)
+
+ combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])
+ temporal_ctx = K.sum(combined, axis=-1)
+
+ return temporal_ctx, sparse_weights, static_gate
+
+ historical_features, historical_flags, _ = lstm_combine_and_mask(historical_inputs)
+ future_features, future_flags, _ = lstm_combine_and_mask(future_inputs)
+
+ # LSTM layer
+ def get_lstm(return_state):
+ """Returns LSTM cell initialized with default parameters."""
+ if self.use_cudnn:
+ lstm = tf.keras.layers.CuDNNLSTM(
+ self.hidden_layer_size,
+ return_sequences=True,
+ return_state=return_state,
+ stateful=False,
+ )
+ else:
+ lstm = tf.keras.layers.LSTM(
+ self.hidden_layer_size,
+ return_sequences=True,
+ return_state=return_state,
+ stateful=False,
+ # Additional params to ensure LSTM matches CuDNN, See TF 2.0 :
+ # (https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM)
+ activation="tanh",
+ recurrent_activation="sigmoid",
+ recurrent_dropout=0,
+ unroll=False,
+ use_bias=True,
+ )
+ return lstm
+
+ history_lstm, state_h, state_c = get_lstm(return_state=True)(
+ historical_features, initial_state=[static_context_state_h, static_context_state_c]
+ )
+
+ future_lstm = get_lstm(return_state=False)(future_features, initial_state=[state_h, state_c])
+
+ lstm_layer = concat([history_lstm, future_lstm], axis=1)
+
+ # Apply gated skip connection
+ input_embeddings = concat([historical_features, future_features], axis=1)
+
+ lstm_layer, _ = apply_gating_layer(lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None)
+ temporal_feature_layer = add_and_norm([lstm_layer, input_embeddings])
+
+ # Static enrichment layers
+ expanded_static_context = K.expand_dims(static_context_enrichment, axis=1)
+ enriched, _ = gated_residual_network(
+ temporal_feature_layer,
+ self.hidden_layer_size,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=True,
+ additional_context=expanded_static_context,
+ return_gate=True,
+ )
+
+ # Decoder self attention
+ self_attn_layer = InterpretableMultiHeadAttention(
+ self.num_heads, self.hidden_layer_size, dropout=self.dropout_rate
+ )
+
+ mask = get_decoder_mask(enriched)
+ x, self_att = self_attn_layer(enriched, enriched, enriched, mask=mask)
+
+ x, _ = apply_gating_layer(x, self.hidden_layer_size, dropout_rate=self.dropout_rate, activation=None)
+ x = add_and_norm([x, enriched])
+
+ # Nonlinear processing on outputs
+ decoder = gated_residual_network(
+ x, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=True
+ )
+
+ # Final skip connection
+ decoder, _ = apply_gating_layer(decoder, self.hidden_layer_size, activation=None)
+ transformer_layer = add_and_norm([decoder, temporal_feature_layer])
+
+ # Attention components for explainability
+ attention_components = {
+ # Temporal attention weights
+ "decoder_self_attn": self_att,
+ # Static variable selection weights
+ "static_flags": static_weights[Ellipsis, 0],
+ # Variable selection weights of past inputs
+ "historical_flags": historical_flags[Ellipsis, 0, :],
+ # Variable selection weights of future inputs
+ "future_flags": future_flags[Ellipsis, 0, :],
+ }
+
+ return transformer_layer, all_inputs, attention_components
+
+ def build_model(self):
+ """Build model and defines training losses.
+
+ Returns:
+ Fully defined Keras model.
+ """
+
+ with tf.variable_scope(self.name):
+
+ transformer_layer, all_inputs, attention_components = self._build_base_graph()
+
+ outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.output_size * len(self.quantiles)))(
+ transformer_layer[Ellipsis, self.num_encoder_steps :, :]
+ )
+
+ self._attention_components = attention_components
+
+ adam = tf.keras.optimizers.Adam(lr=self.learning_rate, clipnorm=self.max_gradient_norm)
+
+ model = tf.keras.Model(inputs=all_inputs, outputs=outputs)
+
+ print(model.summary())
+
+ valid_quantiles = self.quantiles
+ output_size = self.output_size
+
+ class QuantileLossCalculator(object):
+ """Computes the combined quantile loss for prespecified quantiles.
+
+ Attributes:
+ quantiles: Quantiles to compute losses
+ """
+
+ def __init__(self, quantiles):
+ """Initializes computer with quantiles for loss calculations.
+
+ Args:
+ quantiles: Quantiles to use for computations.
+ """
+ self.quantiles = quantiles
+
+ def quantile_loss(self, a, b):
+ """Returns quantile loss for specified quantiles.
+
+ Args:
+ a: Targets
+ b: Predictions
+ """
+ quantiles_used = set(self.quantiles)
+
+ loss = 0.0
+ for i, quantile in enumerate(valid_quantiles):
+ if quantile in quantiles_used:
+ loss += utils.tensorflow_quantile_loss(
+ a[Ellipsis, output_size * i : output_size * (i + 1)],
+ b[Ellipsis, output_size * i : output_size * (i + 1)],
+ quantile,
+ )
+ return loss
+
+ quantile_loss = QuantileLossCalculator(valid_quantiles).quantile_loss
+
+ model.compile(loss=quantile_loss, optimizer=adam, sample_weight_mode="temporal")
+
+ self._input_placeholder = all_inputs
+
+ return model
+
+ def fit(self, train_df=None, valid_df=None):
+ """Fits deep neural network for given training and validation data.
+
+ Args:
+ train_df: DataFrame for training data
+ valid_df: DataFrame for validation data
+ """
+
+ print("*** Fitting {} ***".format(self.name))
+
+ # Add relevant callbacks
+ callbacks = [
+ tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=self.early_stopping_patience, min_delta=1e-4),
+ tf.keras.callbacks.ModelCheckpoint(
+ filepath=self.get_keras_saved_path(self._temp_folder),
+ monitor="val_loss",
+ save_best_only=True,
+ save_weights_only=True,
+ ),
+ tf.keras.callbacks.TerminateOnNaN(),
+ ]
+
+ print("Getting batched_data")
+ if train_df is None:
+ print("Using cached training data")
+ train_data = TFTDataCache.get("train")
+ else:
+ train_data = self._batch_data(train_df)
+
+ if valid_df is None:
+ print("Using cached validation data")
+ valid_data = TFTDataCache.get("valid")
+ else:
+ valid_data = self._batch_data(valid_df)
+
+ print("Using keras standard fit")
+
+ def _unpack(data):
+ return data["inputs"], data["outputs"], self._get_active_locations(data["active_entries"])
+
+ # Unpack without sample weights
+ data, labels, active_flags = _unpack(train_data)
+ val_data, val_labels, val_flags = _unpack(valid_data)
+
+ all_callbacks = callbacks
+
+ self.model.fit(
+ x=data,
+ y=np.concatenate([labels, labels, labels], axis=-1),
+ sample_weight=active_flags,
+ epochs=self.num_epochs,
+ batch_size=self.minibatch_size,
+ validation_data=(val_data, np.concatenate([val_labels, val_labels, val_labels], axis=-1), val_flags),
+ callbacks=all_callbacks,
+ shuffle=True,
+ use_multiprocessing=True,
+ workers=self.n_multiprocessing_workers,
+ )
+
+ # Load best checkpoint again
+ tmp_checkpont = self.get_keras_saved_path(self._temp_folder)
+ if os.path.exists(tmp_checkpont):
+ self.load(self._temp_folder, use_keras_loadings=True)
+
+ else:
+ print("Cannot load from {}, skipping ...".format(self._temp_folder))
+
+ def evaluate(self, data=None, eval_metric="loss"):
+ """Applies evaluation metric to the training data.
+
+ Args:
+ data: Dataframe for evaluation
+ eval_metric: Evaluation metic to return, based on model definition.
+
+ Returns:
+ Computed evaluation loss.
+ """
+
+ if data is None:
+ print("Using cached validation data")
+ raw_data = TFTDataCache.get("valid")
+ else:
+ raw_data = self._batch_data(data)
+
+ inputs = raw_data["inputs"]
+ outputs = raw_data["outputs"]
+ active_entries = self._get_active_locations(raw_data["active_entries"])
+
+ metric_values = self.model.evaluate(
+ x=inputs,
+ y=np.concatenate([outputs, outputs, outputs], axis=-1),
+ sample_weight=active_entries,
+ workers=16,
+ use_multiprocessing=True,
+ )
+
+ metrics = pd.Series(metric_values, self.model.metrics_names)
+
+ return metrics[eval_metric]
+
+ def predict(self, df, return_targets=False):
+ """Computes predictions for a given input dataset.
+
+ Args:
+ df: Input dataframe
+ return_targets: Whether to also return outputs aligned with predictions to
+ faciliate evaluation
+
+ Returns:
+ Input dataframe or tuple of (input dataframe, algined output dataframe).
+ """
+
+ data = self._batch_data(df)
+
+ inputs = data["inputs"]
+ time = data["time"]
+ identifier = data["identifier"]
+ outputs = data["outputs"]
+
+ combined = self.model.predict(inputs, workers=16, use_multiprocessing=True, batch_size=self.minibatch_size)
+
+ # Format output_csv
+ if self.output_size != 1:
+ raise NotImplementedError("Current version only supports 1D targets!")
+
+ def format_outputs(prediction):
+ """Returns formatted dataframes for prediction."""
+
+ flat_prediction = pd.DataFrame(
+ prediction[:, :, 0], columns=["t+{}".format(i) for i in range(self.time_steps - self.num_encoder_steps)]
+ )
+ cols = list(flat_prediction.columns)
+ flat_prediction["forecast_time"] = time[:, self.num_encoder_steps - 1, 0]
+ flat_prediction["identifier"] = identifier[:, 0, 0]
+
+ # Arrange in order
+ return flat_prediction[["forecast_time", "identifier"] + cols]
+
+ # Extract predictions for each quantile into different entries
+ process_map = {
+ "p{}".format(int(q * 100)): combined[Ellipsis, i * self.output_size : (i + 1) * self.output_size]
+ for i, q in enumerate(self.quantiles)
+ }
+
+ if return_targets:
+ # Add targets if relevant
+ process_map["targets"] = outputs
+
+ return {k: format_outputs(process_map[k]) for k in process_map}
+
+ def get_attention(self, df):
+ """Computes TFT attention weights for a given dataset.
+
+ Args:
+ df: Input dataframe
+
+ Returns:
+ Dictionary of numpy arrays for temporal attention weights and variable
+ selection weights, along with their identifiers and time indices
+ """
+
+ data = self._batch_data(df)
+ inputs = data["inputs"]
+ identifiers = data["identifier"]
+ time = data["time"]
+
+ def get_batch_attention_weights(input_batch):
+ """Returns weights for a given minibatch of data."""
+ input_placeholder = self._input_placeholder
+ attention_weights = {}
+ for k in self._attention_components:
+ attention_weight = tf.keras.backend.get_session().run(
+ self._attention_components[k], {input_placeholder: input_batch.astype(np.float32)}
+ )
+ attention_weights[k] = attention_weight
+ return attention_weights
+
+ # Compute number of batches
+ batch_size = self.minibatch_size
+ n = inputs.shape[0]
+ num_batches = n // batch_size
+ if n - (num_batches * batch_size) > 0:
+ num_batches += 1
+
+ # Split up inputs into batches
+ batched_inputs = [inputs[i * batch_size : (i + 1) * batch_size, Ellipsis] for i in range(num_batches)]
+
+ # Get attention weights, while avoiding large memory increases
+ attention_by_batch = [get_batch_attention_weights(batch) for batch in batched_inputs]
+ attention_weights = {}
+ for k in self._attention_components:
+ attention_weights[k] = []
+ for batch_weights in attention_by_batch:
+ attention_weights[k].append(batch_weights[k])
+
+ if len(attention_weights[k][0].shape) == 4:
+ tmp = np.concatenate(attention_weights[k], axis=1)
+ else:
+ tmp = np.concatenate(attention_weights[k], axis=0)
+
+ del attention_weights[k]
+ gc.collect()
+ attention_weights[k] = tmp
+
+ attention_weights["identifiers"] = identifiers[:, 0, 0]
+ attention_weights["time"] = time[:, :, 0]
+
+ return attention_weights
+
+ # Serialisation.
+ def reset_temp_folder(self):
+ """Deletes and recreates folder with temporary Keras training outputs."""
+ print("Resetting temp folder...")
+ utils.create_folder_if_not_exist(self._temp_folder)
+ shutil.rmtree(self._temp_folder)
+ os.makedirs(self._temp_folder)
+
+ def get_keras_saved_path(self, model_folder):
+ """Returns path to keras checkpoint."""
+ return os.path.join(model_folder, "{}.check".format(self.name))
+
+ def save(self, model_folder):
+ """Saves optimal TFT weights.
+
+ Args:
+ model_folder: Location to serialze model.
+ """
+ # Allows for direct serialisation of tensorflow variables to avoid spurious
+ # issue with Keras that leads to different performance evaluation results
+ # when model is reloaded (https://github.com/keras-team/keras/issues/4875).
+
+ utils.save(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name)
+
+ def load(self, model_folder, use_keras_loadings=False):
+ """Loads TFT weights.
+
+ Args:
+ model_folder: Folder containing serialized models.
+ use_keras_loadings: Whether to load from Keras checkpoint.
+
+ Returns:
+
+ """
+ if use_keras_loadings:
+ # Loads temporary Keras model saved during training.
+ serialisation_path = self.get_keras_saved_path(model_folder)
+ print("Loading model from {}".format(serialisation_path))
+ self.model.load_weights(serialisation_path)
+ else:
+ # Loads tensorflow graph for optimal models.
+ utils.load(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name)
+
+ @classmethod
+ def get_hyperparm_choices(cls):
+ """Returns hyperparameter ranges for random search."""
+ return {
+ "dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9],
+ "hidden_layer_size": [10, 20, 40, 80, 160, 240, 320],
+ "minibatch_size": [64, 128, 256],
+ "learning_rate": [1e-4, 1e-3, 1e-2],
+ "max_gradient_norm": [0.01, 1.0, 100.0],
+ "num_heads": [1, 4],
+ "stack_size": [1],
+ }
diff --git a/examples/benchmarks/TFT/libs/utils.py b/examples/benchmarks/TFT/libs/utils.py
new file mode 100644
index 000000000..4682434d6
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/utils.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Generic helper functions used across codebase."""
+
+import os
+import pathlib
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
+
+
+# Generic.
+def get_single_col_by_input_type(input_type, column_definition):
+ """Returns name of single column.
+
+ Args:
+ input_type: Input type of column to extract
+ column_definition: Column definition list for experiment
+ """
+
+ l = [tup[0] for tup in column_definition if tup[2] == input_type]
+
+ if len(l) != 1:
+ raise ValueError("Invalid number of columns for {}".format(input_type))
+
+ return l[0]
+
+
+def extract_cols_from_data_type(data_type, column_definition, excluded_input_types):
+ """Extracts the names of columns that correspond to a define data_type.
+
+ Args:
+ data_type: DataType of columns to extract.
+ column_definition: Column definition to use.
+ excluded_input_types: Set of input types to exclude
+
+ Returns:
+ List of names for columns with data type specified.
+ """
+ return [tup[0] for tup in column_definition if tup[1] == data_type and tup[2] not in excluded_input_types]
+
+
+# Loss functions.
+def tensorflow_quantile_loss(y, y_pred, quantile):
+ """Computes quantile loss for tensorflow.
+
+ Standard quantile loss as defined in the "Training Procedure" section of
+ the main TFT paper
+
+ Args:
+ y: Targets
+ y_pred: Predictions
+ quantile: Quantile to use for loss calculations (between 0 & 1)
+
+ Returns:
+ Tensor for quantile loss.
+ """
+
+ # Checks quantile
+ if quantile < 0 or quantile > 1:
+ raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))
+
+ prediction_underflow = y - y_pred
+ q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * tf.maximum(
+ -prediction_underflow, 0.0
+ )
+
+ return tf.reduce_sum(q_loss, axis=-1)
+
+
+def numpy_normalised_quantile_loss(y, y_pred, quantile):
+ """Computes normalised quantile loss for numpy arrays.
+
+ Uses the q-Risk metric as defined in the "Training Procedure" section of the
+ main TFT paper.
+
+ Args:
+ y: Targets
+ y_pred: Predictions
+ quantile: Quantile to use for loss calculations (between 0 & 1)
+
+ Returns:
+ Float for normalised quantile loss.
+ """
+ prediction_underflow = y - y_pred
+ weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
+ -prediction_underflow, 0.0
+ )
+
+ quantile_loss = weighted_errors.mean()
+ normaliser = y.abs().mean()
+
+ return 2 * quantile_loss / normaliser
+
+
+# OS related functions.
+def create_folder_if_not_exist(directory):
+ """Creates folder if it doesn't exist.
+
+ Args:
+ directory: Folder path to create.
+ """
+ # Also creates directories recursively
+ pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
+
+
+# Tensorflow related functions.
+def get_default_tensorflow_config(tf_device="gpu", gpu_id=0):
+ """Creates tensorflow config for graphs to run on CPU or GPU.
+
+ Specifies whether to run graph on gpu or cpu and which GPU ID to use for multi
+ GPU machines.
+
+ Args:
+ tf_device: 'cpu' or 'gpu'
+ gpu_id: GPU ID to use if relevant
+
+ Returns:
+ Tensorflow config.
+ """
+
+ if tf_device == "cpu":
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # for training on cpu
+ tf_config = tf.ConfigProto(log_device_placement=False, device_count={"GPU": 0})
+
+ else:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
+ print("Selecting GPU ID={}".format(gpu_id))
+
+ tf_config = tf.ConfigProto(log_device_placement=False)
+ tf_config.gpu_options.allow_growth = True
+
+ return tf_config
+
+
+def save(tf_session, model_folder, cp_name, scope=None):
+ """Saves Tensorflow graph to checkpoint.
+
+ Saves all trainiable variables under a given variable scope to checkpoint.
+
+ Args:
+ tf_session: Session containing graph
+ model_folder: Folder to save models
+ cp_name: Name of Tensorflow checkpoint
+ scope: Variable scope containing variables to save
+ """
+ # Save model
+ if scope is None:
+ saver = tf.train.Saver()
+ else:
+ var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
+ saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
+
+ save_path = saver.save(tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name)))
+ print("Model saved to: {0}".format(save_path))
+
+
+def load(tf_session, model_folder, cp_name, scope=None, verbose=False):
+ """Loads Tensorflow graph from checkpoint.
+
+ Args:
+ tf_session: Session to load graph into
+ model_folder: Folder containing serialised model
+ cp_name: Name of Tensorflow checkpoint
+ scope: Variable scope to use.
+ verbose: Whether to print additional debugging information.
+ """
+ # Load model proper
+ load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
+
+ print("Loading model from {0}".format(load_path))
+
+ print_weights_in_checkpoint(model_folder, cp_name)
+
+ initial_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
+
+ # Saver
+ if scope is None:
+ saver = tf.train.Saver()
+ else:
+ var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
+ saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
+ # Load
+ saver.restore(tf_session, load_path)
+ all_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
+
+ if verbose:
+ print("Restored {0}".format(",".join(initial_vars.difference(all_vars))))
+ print("Existing {0}".format(",".join(all_vars.difference(initial_vars))))
+ print("All {0}".format(",".join(all_vars)))
+
+ print("Done.")
+
+
+def print_weights_in_checkpoint(model_folder, cp_name):
+ """Prints all weights in Tensorflow checkpoint.
+
+ Args:
+ model_folder: Folder containing checkpoint
+ cp_name: Name of checkpoint
+
+ Returns:
+
+ """
+ load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
+
+ print_tensors_in_checkpoint_file(file_name=load_path, tensor_name="", all_tensors=True, all_tensor_names=True)
diff --git a/examples/benchmarks/TFT/requirements.txt b/examples/benchmarks/TFT/requirements.txt
new file mode 100644
index 000000000..04234aaed
--- /dev/null
+++ b/examples/benchmarks/TFT/requirements.txt
@@ -0,0 +1,3 @@
+tensorflow-gpu==1.15.0
+numpy == 1.19.4
+pandas==1.1.0
\ No newline at end of file
diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py
new file mode 100644
index 000000000..388ec7f14
--- /dev/null
+++ b/examples/benchmarks/TFT/tft.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+import pandas as pd
+import tensorflow.compat.v1 as tf
+import data_formatters.base
+import expt_settings.configs
+import libs.hyperparam_opt
+import libs.tft_model
+import libs.utils as utils
+import os
+import datetime as dte
+
+
+from qlib.model.base import ModelFT
+from qlib.data.dataset import DatasetH
+from qlib.data.dataset.handler import DataHandlerLP
+
+
+# To register new datasets, please add them here.
+ALLOW_DATASET = ["Alpha158"]
+DATASET_SETTING = {
+ "Alpha158": {
+ "feature_col": ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "ROC60", "RESI10"],
+ "label_col": ["LABEL0"],
+ },
+}
+# To register new datasets, please add their configurations here.
+
+
+def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
+ return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
+
+
+def fill_test_na(test_df):
+ test_df_res = test_df.copy()
+ feature_cols = ~test_df_res.columns.str.contains("label", case=False)
+ test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
+ test_df_res.loc[:, feature_cols] = test_feature_fna
+ return test_df_res
+
+
+def process_qlib_data(df, dataset, fillna=False):
+ """Prepare data to fit the TFT model.
+
+ Args:
+ df: Original DataFrame.
+ fillna: Whether to fill the data with the mean values.
+
+ Returns:
+ Transformed DataFrame.
+
+ """
+ # Several features selected manually
+ feature_col = DATASET_SETTING[dataset]["feature_col"]
+ label_col = DATASET_SETTING[dataset]["label_col"]
+ temp_df = df.loc[:, feature_col + label_col]
+ if fillna:
+ temp_df = fill_test_na(temp_df)
+ temp_df = temp_df.swaplevel()
+ temp_df = temp_df.sort_index()
+ temp_df = temp_df.reset_index(level=0)
+ dates = pd.to_datetime(temp_df.index)
+ temp_df["date"] = dates
+ temp_df["day_of_week"] = dates.dayofweek
+ temp_df["month"] = dates.month
+ temp_df["year"] = dates.year
+ temp_df["const"] = 1.0
+ return temp_df
+
+
+def process_predicted(df, col_name):
+ """Transform the TFT predicted data into Qlib format.
+
+ Args:
+ df: Original DataFrame.
+ fillna: New column name.
+
+ Returns:
+ Transformed DataFrame.
+
+ """
+ df_res = df.copy()
+ df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
+ df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
+ df_res = df_res[[col_name]]
+ return df_res
+
+
+def format_score(forecast_df, col_name="pred", label_shift=5):
+ pred = process_predicted(forecast_df, col_name=col_name)
+ pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
+ pred = pred.dropna()[col_name]
+ return pred
+
+
+def transform_df(df, col_name="LABEL0"):
+ df_res = df["feature"]
+ df_res[col_name] = df["label"]
+ return df_res
+
+
+class TFTModel(ModelFT):
+ """TFT Model"""
+
+ def __init__(self, **kwargs):
+ self.model = None
+
+ def _prepare_data(self, dataset: DatasetH):
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ return transform_df(df_train), transform_df(df_valid)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ DATASET="Alpha158",
+ MODEL_FOLDER="qlib_alpha158_model",
+ LABEL_COL="LABEL0",
+ LABEL_SHIFT=5,
+ USE_GPU_ID=0,
+ **kwargs
+ ):
+
+ if DATASET not in ALLOW_DATASET:
+ raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
+
+ dtrain, dvalid = self._prepare_data(dataset)
+ dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
+ dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
+
+ train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
+ valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
+
+ ExperimentConfig = expt_settings.configs.ExperimentConfig
+ config = ExperimentConfig(DATASET)
+ self.data_formatter = config.make_data_formatter()
+ self.model_folder = MODEL_FOLDER
+ self.gpu_id = USE_GPU_ID
+ self.label_shift = LABEL_SHIFT
+ self.expt_name = DATASET
+ self.label_col = LABEL_COL
+
+ use_gpu = (True, self.gpu_id)
+ # ===========================Training Process===========================
+ ModelClass = libs.tft_model.TemporalFusionTransformer
+ if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
+ raise ValueError(
+ "Data formatters should inherit from"
+ + "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
+ )
+
+ default_keras_session = tf.keras.backend.get_session()
+
+ if use_gpu[0]:
+ self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
+ else:
+ self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
+
+ self.data_formatter.set_scalers(train)
+
+ # Sets up default params
+ fixed_params = self.data_formatter.get_experiment_params()
+ params = self.data_formatter.get_default_model_params()
+
+ # Wendi: 合并调优的参数和非调优的参数
+ params = {**params, **fixed_params}
+
+ if not os.path.exists(self.model_folder):
+ os.makedirs(self.model_folder)
+ params["model_folder"] = self.model_folder
+
+ print("*** Begin training ***")
+ best_loss = np.Inf
+
+ tf.reset_default_graph()
+
+ self.tf_graph = tf.Graph()
+ with self.tf_graph.as_default():
+ self.sess = tf.Session(config=self.tf_config)
+ tf.keras.backend.set_session(self.sess)
+ self.model = ModelClass(params, use_cudnn=use_gpu[0])
+ self.sess.run(tf.global_variables_initializer())
+ self.model.fit(train_df=train, valid_df=valid)
+ print("*** Finished training ***")
+ saved_model_dir = self.model_folder + "/" + "saved_model"
+ if not os.path.exists(saved_model_dir):
+ os.makedirs(saved_model_dir)
+ self.model.save(saved_model_dir)
+
+ def extract_numerical_data(data):
+ """Strips out forecast time and identifier columns."""
+ return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
+
+ # p50_loss = utils.numpy_normalised_quantile_loss(
+ # extract_numerical_data(targets), extract_numerical_data(p50_forecast),
+ # 0.5)
+ # p90_loss = utils.numpy_normalised_quantile_loss(
+ # extract_numerical_data(targets), extract_numerical_data(p90_forecast),
+ # 0.9)
+ tf.keras.backend.set_session(default_keras_session)
+ print("Training completed.".format(dte.datetime.now()))
+ # ===========================Training Process===========================
+
+ def predict(self, dataset):
+ if self.model is None:
+ raise ValueError("model is not fitted yet!")
+ d_test = dataset.prepare("test", col_set=["feature", "label"])
+ d_test = transform_df(d_test)
+ d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
+ test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
+
+ use_gpu = (True, self.gpu_id)
+ # ===========================Predicting Process===========================
+ default_keras_session = tf.keras.backend.get_session()
+
+ # Sets up default params
+ fixed_params = self.data_formatter.get_experiment_params()
+ params = self.data_formatter.get_default_model_params()
+ params = {**params, **fixed_params}
+
+ print("*** Begin predicting ***")
+ tf.reset_default_graph()
+
+ with self.tf_graph.as_default():
+ tf.keras.backend.set_session(self.sess)
+ output_map = self.model.predict(test, return_targets=True)
+ targets = self.data_formatter.format_predictions(output_map["targets"])
+ p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
+ p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
+ tf.keras.backend.set_session(default_keras_session)
+
+ predict50 = format_score(p50_forecast, "pred", 1)
+ predict90 = format_score(p90_forecast, "pred", 1)
+ predict = (predict50 + predict90) / 2 # self.label_shift
+ # ===========================Predicting Process===========================
+ return predict
+
+ def finetune(self, dataset: DatasetH):
+ """
+ finetune model
+ Parameters
+ ----------
+ dataset : DatasetH
+ dataset for finetuning
+ """
+ pass
diff --git a/examples/benchmarks/TFT/workflow_config_tft.yaml b/examples/benchmarks/TFT/workflow_config_tft.yaml
new file mode 100644
index 000000000..d8ee14e71
--- /dev/null
+++ b/examples/benchmarks/TFT/workflow_config_tft.yaml
@@ -0,0 +1,52 @@
+sys:
+ rel_path: .
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: TFTModel
+ module_path: tft
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/XGBoost/README.md b/examples/benchmarks/XGBoost/README.md
new file mode 100644
index 000000000..33e04b23b
--- /dev/null
+++ b/examples/benchmarks/XGBoost/README.md
@@ -0,0 +1,3 @@
+# XGBoost
+* Code: [https://github.com/dmlc/xgboost](https://github.com/dmlc/xgboost)
+* Paper: XGBoost: A Scalable Tree Boosting System. [https://dl.acm.org/doi/pdf/10.1145/2939672.2939785](https://dl.acm.org/doi/pdf/10.1145/2939672.2939785).
\ No newline at end of file
diff --git a/examples/benchmarks/XGBoost/requirements.txt b/examples/benchmarks/XGBoost/requirements.txt
new file mode 100644
index 000000000..077f343e5
--- /dev/null
+++ b/examples/benchmarks/XGBoost/requirements.txt
@@ -0,0 +1,3 @@
+numpy==1.17.4
+pandas==1.1.2
+xgboost==1.2.1
\ No newline at end of file
diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
new file mode 100644
index 000000000..1352c496d
--- /dev/null
+++ b/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
@@ -0,0 +1,63 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: XGBModel
+ module_path: qlib.contrib.model.xgboost
+ kwargs:
+ eval_metric: rmse
+ colsample_bytree: 0.8879
+ eta: 0.0421
+ max_depth: 8
+ n_estimators: 647
+ subsample: 0.8789
+ nthread: 20
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/estimator/analyze_from_estimator.ipynb b/examples/estimator/analyze_from_estimator.ipynb
deleted file mode 100644
index 2ed63bf22..000000000
--- a/examples/estimator/analyze_from_estimator.ipynb
+++ /dev/null
@@ -1,222 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import sys\n",
- "import json\n",
- "import yaml\n",
- "import pickle\n",
- "from pathlib import Path\n",
- "\n",
- "import qlib\n",
- "import pandas as pd\n",
- "from qlib.config import REG_CN\n",
- "from qlib.utils import exists_qlib_data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "CUR_DIR = Path.cwd()\n",
- "MARKET = \"csi300\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# use default data\n",
- "# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data\n",
- "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
- "if not exists_qlib_data(provider_uri):\n",
- " print(f\"Qlib data is not found in {provider_uri}\")\n",
- " sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
- " from get_data import GetData\n",
- " GetData().qlib_data(target_dir=provider_uri)\n",
- "qlib.init(provider_uri=provider_uri, region=REG_CN)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "with CUR_DIR.joinpath('estimator_config.yaml').open() as fp:\n",
- " estimator_name = yaml.load(fp, Loader=yaml.FullLoader)['experiment']['name']\n",
- "with CUR_DIR.joinpath(estimator_name, 'exp_info.json').open() as fp:\n",
- " latest_id = json.load(fp)['id']\n",
- " \n",
- "estimator_dir = CUR_DIR.joinpath(estimator_name, 'sacred', latest_id)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# read estimator result"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred_df = pd.read_pickle(estimator_dir.joinpath('pred.pkl'))\n",
- "report_normal_df = pd.read_pickle(estimator_dir.joinpath('report_normal.pkl'))\n",
- "report_normal_df.index.names = ['index']\n",
- "\n",
- "analysis_df = pd.read_pickle(estimator_dir.joinpath('analysis.pkl'))\n",
- "positions = pickle.load(estimator_dir.joinpath('positions.pkl').open('rb'))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# analyze graphs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from qlib.data import D\n",
- "from qlib.contrib.report import analysis_model, analysis_position\n",
- "pred_df_dates = pred_df.index.get_level_values(level='datetime')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## analysis position"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "stock_ret = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
- "stock_ret.columns = ['label']"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### report"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis_position.report_graph(report_normal_df)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### risk analysis"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## analysis model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "label_df = D.features(D.instruments(MARKET), ['Ref($close, -2)/Ref($close, -1) - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
- "label_df.columns = ['label']"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### score IC"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
- "analysis_position.score_ic_graph(pred_label)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### model performance"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis_model.model_performance_graph(pred_label)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
\ No newline at end of file
diff --git a/examples/estimator/estimator_config.yaml b/examples/estimator/estimator_config.yaml
deleted file mode 100644
index 7b532ca40..000000000
--- a/examples/estimator/estimator_config.yaml
+++ /dev/null
@@ -1,53 +0,0 @@
-experiment:
- name: estimator_example
- observer_type: file_storage
- mode: train
-
-model:
- class: LGBModel
- module_path: qlib.contrib.model.gbdt
- args:
- loss: mse
- colsample_bytree: 0.8879
- learning_rate: 0.0421
- subsample: 0.8789
- lambda_l1: 205.6999
- lambda_l2: 580.9768
- max_depth: 8
- num_leaves: 210
- num_threads: 20
-data:
- class: Alpha158
- args:
- dropna_label: True
- filter:
- market: csi300
-trainer:
- class: StaticTrainer
- args:
- train_start_date: 2008-01-01
- train_end_date: 2014-12-31
- validate_start_date: 2015-01-01
- validate_end_date: 2016-12-31
- test_start_date: 2017-01-01
- test_end_date: 2020-08-01
-strategy:
- class: TopkDropoutStrategy
- args:
- topk: 50
- n_drop: 5
-backtest:
- normal_backtest_args:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: SH000300
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
-
-qlib_data:
- # when testing, please modify the following parameters according to the specific environment
- provider_uri: "~/.qlib/qlib_data/cn_data"
- region: "cn"
diff --git a/examples/estimator/estimator_config_dnn.yaml b/examples/estimator/estimator_config_dnn.yaml
deleted file mode 100644
index a4a9d18ff..000000000
--- a/examples/estimator/estimator_config_dnn.yaml
+++ /dev/null
@@ -1,55 +0,0 @@
-experiment:
- name: estimator_example
- observer_type: file_storage
- mode: train
-
-model:
- module_path: qlib.contrib.model.pytorch_nn
- class: DNNModelPytorch
- args:
- loss: mse
- input_dim: 158
- output_dim: 1
- lr: 0.002
- lr_decay: 0.96
- lr_decay_steps: 100
- optimizer: 'adam'
- max_steps: 8000
- batch_size: 4096
- GPU: '0'
-data:
- class: Alpha158
- args:
- dropna_label: True
- dropna_feature: True
- filter:
- market: csi300
-trainer:
- class: StaticTrainer
- args:
- train_start_date: 2007-01-01
- train_end_date: 2014-12-31
- validate_start_date: 2015-01-01
- validate_end_date: 2016-12-31
- test_start_date: 2017-01-01
- test_end_date: 2020-08-01
-strategy:
- class: TopkDropoutStrategy
- args:
- topk: 50
- n_drop: 5
-backtest:
- normal_backtest_args:
- verbose: False
- limit_threshold: 0.095
- account: 100000000
- benchmark: SH000300
- deal_price: close
- open_cost: 0.0005
- close_cost: 0.0015
- min_cost: 5
-
-qlib_data:
- # when testing, please modify the following parameters according to the specific environment
- provider_uri: "~/.qlib/qlib_data/cn_data"
- region: "cn"
diff --git a/examples/run_all_model.py b/examples/run_all_model.py
new file mode 100644
index 000000000..8843573ab
--- /dev/null
+++ b/examples/run_all_model.py
@@ -0,0 +1,284 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+import sys
+import fire
+import time
+import venv
+import glob
+import shutil
+import signal
+import inspect
+import tempfile
+import traceback
+import functools
+import statistics
+import subprocess
+from pathlib import Path
+from operator import xor
+from pprint import pprint
+
+import qlib
+from qlib.config import REG_CN
+from qlib.workflow import R
+from qlib.workflow.cli import workflow
+from qlib.utils import exists_qlib_data
+
+
+# init qlib
+provider_uri = "~/.qlib/qlib_data/cn_data"
+exp_folder_name = "run_all_model_records"
+exp_path = str(Path(os.getcwd()).resolve() / exp_folder_name)
+exp_manager = {
+ "class": "MLflowExpManager",
+ "module_path": "qlib.workflow.expm",
+ "kwargs": {
+ "uri": "file:" + exp_path,
+ "default_exp_name": "Experiment",
+ },
+}
+if not exists_qlib_data(provider_uri):
+ print(f"Qlib data is not found in {provider_uri}")
+ sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
+ from get_data import GetData
+
+ GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
+qlib.init(provider_uri=provider_uri, region=REG_CN, exp_manager=exp_manager)
+if os.path.isdir(exp_path):
+ shutil.rmtree(exp_path)
+
+# decorator to check the arguments
+def only_allow_defined_args(function_to_decorate):
+ @functools.wraps(function_to_decorate)
+ def _return_wrapped(*args, **kwargs):
+ """Internal wrapper function."""
+ argspec = inspect.getfullargspec(function_to_decorate)
+ valid_names = set(argspec.args + argspec.kwonlyargs)
+ if "self" in valid_names:
+ valid_names.remove("self")
+ for arg_name in kwargs:
+ if arg_name not in valid_names:
+ raise ValueError("Unknown argument seen '%s', expected: [%s]" % (arg_name, ", ".join(valid_names)))
+ return function_to_decorate(*args, **kwargs)
+
+ return _return_wrapped
+
+
+# function to handle ctrl z and ctrl c
+def handler(signum, frame):
+ os.system("kill -9 %d" % os.getpid())
+
+
+signal.signal(signal.SIGTSTP, handler)
+signal.signal(signal.SIGINT, handler)
+
+# function to calculate the mean and std of a list in the results dictionary
+def cal_mean_std(results) -> dict:
+ mean_std = dict()
+ for fn in results:
+ mean_std[fn] = dict()
+ for metric in results[fn]:
+ mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0]
+ std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0
+ mean_std[fn][metric] = [mean, std]
+ return mean_std
+
+
+# function to create the environment ofr an anaconda environment
+def create_env():
+ # create env
+ temp_dir = tempfile.mkdtemp()
+ env_path = Path(temp_dir).absolute()
+ sys.stderr.write(f"Creating Virtual Environment with path: {env_path}...\n")
+ execute(f"conda create --prefix {env_path} python=3.7 -y")
+ python_path = env_path / "bin" / "python" # TODO: FIX ME!
+ sys.stderr.write("\n")
+ # get anaconda activate path
+ conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME!
+ return env_path, python_path, conda_activate
+
+
+# function to execute the cmd
+def execute(cmd):
+ with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p:
+ for line in p.stdout:
+ sys.stdout.write(line.split("\b")[0])
+ if "\b" in line:
+ sys.stdout.flush()
+ time.sleep(0.1)
+ sys.stdout.write("\b" * 10 + "\b".join(line.split("\b")[1:-1]))
+
+ if p.returncode != 0:
+ return p.stderr
+ else:
+ return None
+
+
+# function to get all the folders benchmark folder
+def get_all_folders(models, exclude) -> dict:
+ folders = dict()
+ if isinstance(models, str):
+ model_list = models.split(",")
+ models = [m.lower().strip("[ ]") for m in model_list]
+ elif isinstance(models, list):
+ models = [m.lower() for m in models]
+ elif models is None:
+ models = [f.name.lower() for f in os.scandir("benchmarks")]
+ else:
+ raise ValueError("Input models type is not supported. Please provide str or list without space.")
+ for f in os.scandir("benchmarks"):
+ add = xor(bool(f.name.lower() in models), bool(exclude))
+ if add:
+ path = Path("benchmarks") / f.name
+ folders[f.name] = str(path.resolve())
+ return folders
+
+
+# function to get all the files under the model folder
+def get_all_files(folder_path) -> (str, str):
+ yaml_path = str(Path(f"{folder_path}") / "*.yaml")
+ req_path = str(Path(f"{folder_path}") / "*.txt")
+ return glob.glob(yaml_path)[0], glob.glob(req_path)[0]
+
+
+# function to retrieve all the results
+def get_all_results(folders) -> dict:
+ results = dict()
+ for fn in folders:
+ exp = R.get_exp(experiment_name=fn, create=False)
+ recorders = exp.list_recorders()
+ result = dict()
+ result["annualized_return_with_cost"] = list()
+ result["information_ratio_with_cost"] = list()
+ result["max_drawdown_with_cost"] = list()
+ for recorder_id in recorders:
+ if recorders[recorder_id].status == "FINISHED":
+ recorder = R.get_recorder(recorder_id=recorder_id, experiment_name=fn)
+ metrics = recorder.list_metrics()
+ result["annualized_return_with_cost"].append(metrics["excess_return_with_cost.annualized_return"])
+ result["information_ratio_with_cost"].append(metrics["excess_return_with_cost.information_ratio"])
+ result["max_drawdown_with_cost"].append(metrics["excess_return_with_cost.max_drawdown"])
+ results[fn] = result
+ return results
+
+
+# function to generate and save markdown table
+def gen_and_save_md_table(metrics):
+ table = "| Model Name | Annualized Return | Information Ratio | Max Drawdown |\n"
+ table += "|---|---|---|---|\n"
+ for fn in metrics:
+ ar = metrics[fn]["annualized_return_with_cost"]
+ ir = metrics[fn]["information_ratio_with_cost"]
+ md = metrics[fn]["max_drawdown_with_cost"]
+ table += f"| {fn} | {ar[0]:9.4f}±{ar[1]:9.2f} | {ir[0]:9.4f}±{ir[1]:9.2f}| {md[0]:9.4f}±{md[1]:9.2f} |\n"
+ pprint(table)
+ with open("table.md", "w") as f:
+ f.write(table)
+ return table
+
+
+# function to run the all the models
+@only_allow_defined_args
+def run(times=1, models=None, exclude=False):
+ """
+ Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
+ Any PR to enhance this method is highly welcomed.
+
+ Parameters:
+ -----------
+ times : int
+ determines how many times the model should be running.
+ models : str or list
+ determines the specific model or list of models to run or exclude.
+ exclude : boolean
+ determines whether the model being used is excluded or included.
+
+ Usage:
+ -------
+ Here are some use cases of the function in the bash:
+
+ .. code-block:: bash
+
+ # Case 1 - run all models multiple times
+ python run_all_model.py 3
+
+ # Case 2 - run specific models multiple times
+ python run_all_model.py 3 mlp
+
+ # Case 3 - run other models except those are given as arguments for multiple times
+ python run_all_model.py 3 [mlp,tft,lstm] True
+
+ # Case 4 - run specific models for one time
+ python run_all_model.py --models=[mlp,lightgbm]
+
+ # Case 5 - run other models except those are given as aruments for one time
+ python run_all_model.py --models=[mlp,tft,sfm] --exclude=True
+
+ """
+ # get all folders
+ folders = get_all_folders(models, exclude)
+ # init error messages:
+ errors = dict()
+ # run all the model for iterations
+ for fn in folders:
+ # create env by anaconda
+ env_path, python_path, conda_activate = create_env()
+ # get all files
+ sys.stderr.write("Retrieving files...\n")
+ yaml_path, req_path = get_all_files(folders[fn])
+ sys.stderr.write("\n")
+ # install requirements.txt
+ sys.stderr.write("Installing requirements.txt...\n")
+ execute(f"{python_path} -m pip install -r {req_path}")
+ sys.stderr.write("\n")
+ # setup gpu for tft
+ if fn == "TFT":
+ execute(
+ f"conda install -y --prefix {env_path} anaconda cudatoolkit=10.0 && conda install -y --prefix {env_path} cudnn"
+ )
+ sys.stderr.write("\n")
+ # install qlib
+ sys.stderr.write("Installing qlib...\n")
+ execute(f"{python_path} -m pip install --upgrade cython") # TODO: FIX ME!
+ if fn == "TFT":
+ execute(
+ f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
+ ) # TODO: FIX ME!
+ else:
+ execute(
+ f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall -e git+https://github.com/you-n-g/qlib#egg=pyqlib"
+ ) # TODO: FIX ME!
+ sys.stderr.write("\n")
+ # run workflow_by_config for multiple times
+ for i in range(times):
+ sys.stderr.write(f"Running the model: {fn} for iteration {i+1}...\n")
+ errs = execute(
+ f"{python_path} {env_path / 'src/pyqlib/qlib/workflow/cli.py'} {yaml_path} {fn} {exp_folder_name}"
+ )
+ if errs is not None:
+ _errs = errors.get(fn, {})
+ _errs.update({i: errs})
+ errors[fn] = _errs
+ sys.stderr.write("\n")
+ # remove env
+ sys.stderr.write(f"Deleting the environment: {env_path}...\n")
+ shutil.rmtree(env_path)
+ # getting all results
+ sys.stderr.write(f"Retrieving results...\n")
+ results = get_all_results(folders)
+ # calculating the mean and std
+ sys.stderr.write(f"Calculating the mean and std of results...\n")
+ results = cal_mean_std(results)
+ # generating md table
+ sys.stderr.write(f"Generating markdown table...\n")
+ gen_and_save_md_table(results)
+ sys.stderr.write("\n")
+ # print erros
+ sys.stderr.write(f"Here are some of the errors of the models...\n")
+ pprint(errors)
+ sys.stderr.write("\n")
+
+
+if __name__ == "__main__":
+ fire.Fire(run) # run all the model
diff --git a/examples/train_and_backtest.py b/examples/train_and_backtest.py
deleted file mode 100644
index 045587f52..000000000
--- a/examples/train_and_backtest.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import sys
-from pathlib import Path
-
-import qlib
-import pandas as pd
-from qlib.config import REG_CN
-from qlib.contrib.model.gbdt import LGBModel
-from qlib.contrib.estimator.handler import Alpha158
-from qlib.contrib.strategy.strategy import TopkDropoutStrategy
-from qlib.contrib.evaluate import (
- backtest as normal_backtest,
- risk_analysis,
-)
-from qlib.utils import exists_qlib_data
-
-
-if __name__ == "__main__":
-
- # use default data
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
- from get_data import GetData
-
- GetData().qlib_data(target_dir=provider_uri)
-
- qlib.init(provider_uri=provider_uri, region=REG_CN)
-
- MARKET = "CSI300"
- BENCHMARK = "SH000300"
-
- ###################################
- # train model
- ###################################
- DATA_HANDLER_CONFIG = {
- "dropna_label": True,
- "start_date": "2008-01-01",
- "end_date": "2020-08-01",
- "market": MARKET,
- }
-
- TRAINER_CONFIG = {
- "train_start_date": "2008-01-01",
- "train_end_date": "2014-12-31",
- "validate_start_date": "2015-01-01",
- "validate_end_date": "2016-12-31",
- "test_start_date": "2017-01-01",
- "test_end_date": "2020-08-01",
- }
-
- # use default DataHandler
- # custom DataHandler, refer to: TODO: DataHandler API url
- x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
- **TRAINER_CONFIG
- )
-
- MODEL_CONFIG = {
- "loss": "mse",
- "colsample_bytree": 0.8879,
- "learning_rate": 0.0421,
- "subsample": 0.8789,
- "lambda_l1": 205.6999,
- "lambda_l2": 580.9768,
- "max_depth": 8,
- "num_leaves": 210,
- "num_threads": 20,
- }
- # use default model
- # custom Model, refer to: TODO: Model API url
- model = LGBModel(**MODEL_CONFIG)
- model.fit(x_train, y_train, x_validate, y_validate)
- _pred = model.predict(x_test)
- _pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
-
- # backtest requires pred_score
- pred_score = pd.DataFrame(index=_pred.index)
- pred_score["score"] = _pred.iloc(axis=1)[0]
-
- # save pred_score to file
- pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
- pred_score_path.parent.mkdir(exist_ok=True, parents=True)
- pred_score.to_pickle(pred_score_path)
-
- ###################################
- # backtest
- ###################################
- STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
- }
- BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": BENCHMARK,
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
- }
-
- # use default strategy
- # custom Strategy, refer to: TODO: Strategy API url
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
-
- ###################################
- # analyze
- # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
- ###################################
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- print(analysis_df)
diff --git a/examples/train_backtest_analyze.ipynb b/examples/train_backtest_analyze.ipynb
deleted file mode 100644
index d8987b58f..000000000
--- a/examples/train_backtest_analyze.ipynb
+++ /dev/null
@@ -1,338 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import sys\n",
- "from pathlib import Path\n",
- "\n",
- "import qlib\n",
- "import pandas as pd\n",
- "from qlib.config import REG_CN\n",
- "from qlib.contrib.model.gbdt import LGBModel\n",
- "from qlib.contrib.estimator.handler import Alpha158\n",
- "from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
- "from qlib.contrib.evaluate import (\n",
- " backtest as normal_backtest,\n",
- " risk_analysis,\n",
- ")\n",
- "from qlib.utils import exists_qlib_data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "# use default data\n",
- "# NOTE: need to download data from remote: python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn\n",
- "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
- "if not exists_qlib_data(provider_uri):\n",
- " print(f\"Qlib data is not found in {provider_uri}\")\n",
- " sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
- " from get_data import GetData\n",
- " GetData().qlib_data(target_dir=provider_uri)\n",
- "qlib.init(provider_uri=provider_uri, region=REG_CN)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "MARKET = \"csi300\"\n",
- "BENCHMARK = \"SH000300\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# train model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "###################################\n",
- "# train model\n",
- "###################################\n",
- "DATA_HANDLER_CONFIG = {\n",
- " \"dropna_label\": True,\n",
- " \"start_date\": \"2008-01-01\",\n",
- " \"end_date\": \"2020-08-01\",\n",
- " \"market\": MARKET,\n",
- "}\n",
- "\n",
- "TRAINER_CONFIG = {\n",
- " \"train_start_date\": \"2008-01-01\",\n",
- " \"train_end_date\": \"2014-12-31\",\n",
- " \"validate_start_date\": \"2015-01-01\",\n",
- " \"validate_end_date\": \"2016-12-31\",\n",
- " \"test_start_date\": \"2017-01-01\",\n",
- " \"test_end_date\": \"2020-08-01\",\n",
- "}\n",
- "\n",
- "# use default DataHandler\n",
- "# custom DataHandler, refer to: TODO: DataHandler api url\n",
- "x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(**TRAINER_CONFIG)\n",
- "\n",
- "\n",
- "MODEL_CONFIG = {\n",
- " \"loss\": \"mse\",\n",
- " \"colsample_bytree\": 0.8879,\n",
- " \"learning_rate\": 0.0421,\n",
- " \"subsample\": 0.8789,\n",
- " \"lambda_l1\": 205.6999,\n",
- " \"lambda_l2\": 580.9768,\n",
- " \"max_depth\": 8,\n",
- " \"num_leaves\": 210,\n",
- " \"num_threads\": 20,\n",
- "}\n",
- "# use default model\n",
- "# custom Model, refer to: TODO: Model api url\n",
- "model = LGBModel(**MODEL_CONFIG)\n",
- "model.fit(x_train, y_train, x_validate, y_validate)\n",
- "_pred = model.predict(x_test)\n",
- "_pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)\n",
- "\n",
- "# backtest requires pred_score\n",
- "pred_score = pd.DataFrame(index=_pred.index)\n",
- "pred_score[\"score\"] = _pred.iloc(axis=1)[0]\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# backtest"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "###################################\n",
- "# backtest\n",
- "###################################\n",
- "STRATEGY_CONFIG = {\n",
- " \"topk\": 50,\n",
- " \"n_drop\": 5}\n",
- "BACKTEST_CONFIG = {\n",
- " \"verbose\": False,\n",
- " \"limit_threshold\": 0.095,\n",
- " \"account\": 100000000,\n",
- " \"benchmark\": BENCHMARK,\n",
- " \"deal_price\": \"close\",\n",
- " \"open_cost\": 0.0005,\n",
- " \"close_cost\": 0.0015,\n",
- " \"min_cost\": 5,\n",
- " \n",
- "}\n",
- "\n",
- "# use default strategy\n",
- "# custom Strategy, refer to: TODO: Strategy api url\n",
- "strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)\n",
- "report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# analyze"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "###################################\n",
- "# analyze\n",
- "# If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb\n",
- "###################################\n",
- "analysis = dict()\n",
- "analysis[\"excess_return_without_cost\"] = risk_analysis(report_normal[\"return\"] - report_normal[\"bench\"])\n",
- "analysis[\"excess_return_with_cost\"] = risk_analysis(\n",
- " report_normal[\"return\"] - report_normal[\"bench\"] - report_normal[\"cost\"]\n",
- ")\n",
- "analysis_df = pd.concat(analysis) # type: pd.DataFrame\n",
- "print(analysis_df)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# analyze graphs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from qlib.contrib.report import analysis_model, analysis_position\n",
- "from qlib.data import D\n",
- "pred_df_dates = pred_score.index.get_level_values(level='datetime')\n",
- "report_normal_df = report_normal\n",
- "positions = positions_normal\n",
- "pred_df = pred_score"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## analysis position"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "stock_ret = D.features(D.instruments(MARKET), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
- "stock_ret.columns = ['label']"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### report"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis_position.report_graph(report_normal_df)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### risk analysis"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## analysis model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "label_df = D.features(D.instruments(MARKET), ['Ref($close, -2)/Ref($close, -1) - 1'], pred_df_dates.min(), pred_df_dates.max())\n",
- "label_df.columns = ['label']"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### score IC"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
- "analysis_position.score_ic_graph(pred_label)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### model performance"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "analysis_model.model_performance_graph(pred_label)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
\ No newline at end of file
diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb
new file mode 100644
index 000000000..5a992e339
--- /dev/null
+++ b/examples/workflow_by_code.ipynb
@@ -0,0 +1,380 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Copyright (c) Microsoft Corporation.\n",
+ "# Licensed under the MIT License."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "import sys, site\n",
+ "from pathlib import Path\n",
+ "\n",
+ "\n",
+ "try:\n",
+ " import qlib\n",
+ "except ImportError:\n",
+ " # install qlib\n",
+ " ! pip install pyqlib\n",
+ " # reload\n",
+ " site.main()\n",
+ "\n",
+ "scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
+ "if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
+ " # download get_data.py script\n",
+ " scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
+ " scripts_dir.mkdir(parents=True, exist_ok=True)\n",
+ " import requests\n",
+ " with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\") as resp:\n",
+ " with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
+ " fp.write(resp.content)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "import qlib\n",
+ "import pandas as pd\n",
+ "from qlib.config import REG_CN\n",
+ "from qlib.contrib.model.gbdt import LGBModel\n",
+ "from qlib.contrib.data.handler import Alpha158\n",
+ "from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
+ "from qlib.contrib.evaluate import (\n",
+ " backtest as normal_backtest,\n",
+ " risk_analysis,\n",
+ ")\n",
+ "from qlib.utils import exists_qlib_data, init_instance_by_config\n",
+ "from qlib.workflow import R\n",
+ "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
+ "from qlib.utils import flatten_dict\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# use default data\n",
+ "# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
+ "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
+ "if not exists_qlib_data(provider_uri):\n",
+ " print(f\"Qlib data is not found in {provider_uri}\")\n",
+ " sys.path.append(str(scripts_dir))\n",
+ " from get_data import GetData\n",
+ " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
+ "qlib.init(provider_uri=provider_uri, region=REG_CN)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "market = \"csi300\"\n",
+ "benchmark = \"SH000300\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# train model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "###################################\n",
+ "# train model\n",
+ "###################################\n",
+ "data_handler_config = {\n",
+ " \"start_time\": \"2008-01-01\",\n",
+ " \"end_time\": \"2020-08-01\",\n",
+ " \"fit_start_time\": \"2008-01-01\",\n",
+ " \"fit_end_time\": \"2014-12-31\",\n",
+ " \"instruments\": market,\n",
+ "}\n",
+ "\n",
+ "task = {\n",
+ " \"model\": {\n",
+ " \"class\": \"LGBModel\",\n",
+ " \"module_path\": \"qlib.contrib.model.gbdt\",\n",
+ " \"kwargs\": {\n",
+ " \"loss\": \"mse\",\n",
+ " \"colsample_bytree\": 0.8879,\n",
+ " \"learning_rate\": 0.0421,\n",
+ " \"subsample\": 0.8789,\n",
+ " \"lambda_l1\": 205.6999,\n",
+ " \"lambda_l2\": 580.9768,\n",
+ " \"max_depth\": 8,\n",
+ " \"num_leaves\": 210,\n",
+ " \"num_threads\": 20,\n",
+ " },\n",
+ " },\n",
+ " \"dataset\": {\n",
+ " \"class\": \"DatasetH\",\n",
+ " \"module_path\": \"qlib.data.dataset\",\n",
+ " \"kwargs\": {\n",
+ " \"handler\": {\n",
+ " \"class\": \"Alpha158\",\n",
+ " \"module_path\": \"qlib.contrib.data.handler\",\n",
+ " \"kwargs\": data_handler_config,\n",
+ " },\n",
+ " \"segments\": {\n",
+ " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
+ " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
+ " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ "}\n",
+ "\n",
+ "# model initiaiton\n",
+ "model = init_instance_by_config(task[\"model\"])\n",
+ "dataset = init_instance_by_config(task[\"dataset\"])\n",
+ "\n",
+ "# start exp to train model\n",
+ "with R.start(experiment_name=\"train_model\"):\n",
+ " R.log_params(**flatten_dict(task))\n",
+ " model.fit(dataset)\n",
+ " R.save_objects(trained_model=model)\n",
+ " rid = R.get_recorder().id\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# prediction, backtest & analysis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "###################################\n",
+ "# prediction, backtest & analysis\n",
+ "###################################\n",
+ "port_analysis_config = {\n",
+ " \"strategy\": {\n",
+ " \"class\": \"TopkDropoutStrategy\",\n",
+ " \"module_path\": \"qlib.contrib.strategy.strategy\",\n",
+ " \"kwargs\": {\n",
+ " \"topk\": 50,\n",
+ " \"n_drop\": 5,\n",
+ " },\n",
+ " },\n",
+ " \"backtest\": {\n",
+ " \"verbose\": False,\n",
+ " \"limit_threshold\": 0.095,\n",
+ " \"account\": 100000000,\n",
+ " \"benchmark\": benchmark,\n",
+ " \"deal_price\": \"close\",\n",
+ " \"open_cost\": 0.0005,\n",
+ " \"close_cost\": 0.0015,\n",
+ " \"min_cost\": 5,\n",
+ " },\n",
+ "}\n",
+ "\n",
+ "\n",
+ "# backtest and analysis\n",
+ "with R.start(experiment_name=\"backtest_analysis\"):\n",
+ " recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
+ " model = recorder.load_object(\"trained_model\")\n",
+ "\n",
+ " # prediction\n",
+ " recorder = R.get_recorder()\n",
+ " ba_rid = recorder.id\n",
+ " sr = SignalRecord(model, dataset, recorder)\n",
+ " sr.generate()\n",
+ "\n",
+ " # backtest & analysis\n",
+ " par = PortAnaRecord(recorder, port_analysis_config)\n",
+ " par.generate()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# analyze graphs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "from qlib.contrib.report import analysis_model, analysis_position\n",
+ "from qlib.data import D\n",
+ "recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
+ "pred_df = recorder.load_object(\"pred.pkl\")\n",
+ "pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
+ "report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
+ "positions = recorder.load_object(\"portfolio_analysis/positions_normal.pkl\")\n",
+ "analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis.pkl\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## analysis position"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### report"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "analysis_position.report_graph(report_normal_df)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### risk analysis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## analysis model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
+ "label_df.columns = ['label']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### score IC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred_label = pd.concat([label_df, pred_df], axis=1, sort=True).reindex(label_df.index)\n",
+ "analysis_position.score_ic_graph(pred_label)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### model performance"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "analysis_model.model_performance_graph(pred_label)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.9"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py
new file mode 100644
index 000000000..8fdb4332f
--- /dev/null
+++ b/examples/workflow_by_code.py
@@ -0,0 +1,120 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import sys
+from pathlib import Path
+
+import qlib
+import pandas as pd
+from qlib.config import REG_CN
+from qlib.contrib.model.gbdt import LGBModel
+from qlib.contrib.data.handler import Alpha158
+from qlib.contrib.strategy.strategy import TopkDropoutStrategy
+from qlib.contrib.evaluate import (
+ backtest as normal_backtest,
+ risk_analysis,
+)
+from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
+from qlib.workflow import R
+from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
+
+
+if __name__ == "__main__":
+
+ # use default data
+ provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
+ if not exists_qlib_data(provider_uri):
+ print(f"Qlib data is not found in {provider_uri}")
+ sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
+ from get_data import GetData
+
+ GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
+
+ qlib.init(provider_uri=provider_uri, region=REG_CN)
+
+ market = "csi300"
+ benchmark = "SH000300"
+
+ ###################################
+ # train model
+ ###################################
+ data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": market,
+ }
+
+ task = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ "kwargs": {
+ "loss": "mse",
+ "colsample_bytree": 0.8879,
+ "learning_rate": 0.0421,
+ "subsample": 0.8789,
+ "lambda_l1": 205.6999,
+ "lambda_l2": 580.9768,
+ "max_depth": 8,
+ "num_leaves": 210,
+ "num_threads": 20,
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ },
+ }
+
+ port_analysis_config = {
+ "strategy": {
+ "class": "TopkDropoutStrategy",
+ "module_path": "qlib.contrib.strategy.strategy",
+ "kwargs": {
+ "topk": 50,
+ "n_drop": 5,
+ },
+ },
+ "backtest": {
+ "verbose": False,
+ "limit_threshold": 0.095,
+ "account": 100000000,
+ "benchmark": benchmark,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ },
+ }
+
+ # model initiaiton
+ model = init_instance_by_config(task["model"])
+ dataset = init_instance_by_config(task["dataset"])
+
+ # start exp
+ with R.start(experiment_name="workflow"):
+ R.log_params(**flatten_dict(task))
+ model.fit(dataset)
+
+ # prediction
+ recorder = R.get_recorder()
+ sr = SignalRecord(model, dataset, recorder)
+ sr.generate()
+
+ # backtest
+ par = PortAnaRecord(recorder, port_analysis_config)
+ par.generate()
diff --git a/qlib/__init__.py b/qlib/__init__.py
index f63aa26cc..2b8989303 100644
--- a/qlib/__init__.py
+++ b/qlib/__init__.py
@@ -2,19 +2,20 @@
# Licensed under the MIT License.
-__version__ = "0.5.1.dev0"
+__version__ = "0.6.0.alpha"
import os
-import copy
-import logging
import re
-import subprocess
-import platform
+import sys
+import copy
import yaml
+import logging
+import platform
+import subprocess
from pathlib import Path
-from .utils import can_use_cache
-
+from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path
+from .workflow.utils import experiment_exit_handler
# init qlib
def init(default_conf="client", **kwargs):
@@ -22,6 +23,7 @@ def init(default_conf="client", **kwargs):
from .data.data import register_all_wrappers
from .log import get_module_logger, set_log_with_config
from .data.cache import H
+ from .workflow import R, QlibRecorder
C.reset()
H.clear()
@@ -34,17 +36,18 @@ def init(default_conf="client", **kwargs):
if _logging_config:
set_log_with_config(_logging_config)
+ # FIXME: this logger ignored the level in config
LOG = get_module_logger("Initialization", level=logging.INFO)
LOG.info(f"default_conf: {default_conf}.")
C.set_mode(default_conf)
+ C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
for k, v in kwargs.items():
C[k] = v
if k not in C:
LOG.warning("Unrecognized config %s" % k)
- C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
C.resolve_path()
if not (C["expression_cache"] is None and C["dataset_cache"] is None):
@@ -61,12 +64,10 @@ def init(default_conf="client", **kwargs):
if not os.path.exists(C["provider_uri"]):
if C["auto_mount"]:
LOG.error(
- "Invalid provider uri: {}, please check if a valid provider uri has been set. This path does not exist.".format(
- C["provider_uri"]
- )
+ f"Invalid provider uri: {C['provider_uri']}, please check if a valid provider uri has been set. This path does not exist."
)
else:
- LOG.warning("auto_path is False, please make sure {} is mounted".format(C["mount_path"]))
+ LOG.warning(f"auto_path is False, please make sure {C['mount_path']} is mounted")
elif C.get_uri_type() == QlibConfig.NFS_URI:
_mount_nfs_uri(C)
else:
@@ -80,6 +81,13 @@ def init(default_conf="client", **kwargs):
if "flask_server" in C:
LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
+ # set up QlibRecorder
+ exp_manager = init_instance_by_config(C["exp_manager"])
+ qr = QlibRecorder(exp_manager)
+ R.register(qr)
+ # clean up experiment when python program ends
+ experiment_exit_handler()
+
def _mount_nfs_uri(C):
from .log import get_module_logger
@@ -94,9 +102,7 @@ def _mount_nfs_uri(C):
if not C["auto_mount"]:
if not os.path.exists(C["mount_path"]):
raise FileNotFoundError(
- "Invalid mount path: {}! Please mount manually: {} or Set init parameter `auto_mount=True`".format(
- C["mount_path"], mount_command
- )
+ f"Invalid mount path: {C['mount_path']}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
)
else:
# Judging system type
@@ -153,9 +159,7 @@ def _mount_nfs_uri(C):
os.makedirs(C["mount_path"], exist_ok=True)
except Exception:
raise OSError(
- "Failed to create directory {}, please create {} manually!".format(
- C["mount_path"], C["mount_path"]
- )
+ f"Failed to create directory {C['mount_path']}, please create {C['mount_path']} manually!"
)
# check nfs-common
@@ -167,20 +171,18 @@ def _mount_nfs_uri(C):
command_status = os.system(mount_command)
if command_status == 256:
raise OSError(
- "mount {} on {} error! Needs SUDO! Please mount manually: {}".format(
- C["provider_uri"], C["mount_path"], mount_command
- )
+ f"mount {C['provider_uri']} on {C['mount_path']} error! Needs SUDO! Please mount manually: {mount_command}"
)
elif command_status == 32512:
# LOG.error("Command error")
- raise OSError("mount {} on {} error! Command error".format(C["provider_uri"], C["mount_path"]))
+ raise OSError(f"mount {C['provider_uri']} on {C['mount_path']} error! Command error")
elif command_status == 0:
LOG.info("Mount finished")
else:
- LOG.warning("{} on {} is already mounted".format(_remote_uri, _mount_path))
+ LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")
-def init_from_yaml_conf(conf_path):
+def init_from_yaml_conf(conf_path, **kwargs):
"""init_from_yaml_conf
:param conf_path: A path to the qlib config in yml format
@@ -188,5 +190,6 @@ def init_from_yaml_conf(conf_path):
with open(conf_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
+ config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
diff --git a/qlib/config.py b/qlib/config.py
index 1d3ad86cb..869ea99c9 100644
--- a/qlib/config.py
+++ b/qlib/config.py
@@ -14,6 +14,8 @@ Two modes are supported
import copy
from pathlib import Path
import re
+import os
+import multiprocessing
class Config:
@@ -62,6 +64,8 @@ class Config:
REG_CN = "cn"
REG_US = "us"
+NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
+
_default_config = {
# data provider config
"calendar_provider": "LocalCalendarProvider",
@@ -78,7 +82,7 @@ _default_config = {
"calendar_cache": None,
# for simple dataset cache
"local_cache_path": None,
- "kernels": 16,
+ "kernels": NUM_USABLE_CPU,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
"default_disk_cache": 1, # 0:skip/1:use
@@ -124,6 +128,15 @@ _default_config = {
},
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
},
+ # Defatult config for experiment manager
+ "exp_manager": {
+ "class": "MLflowExpManager",
+ "module_path": "qlib.workflow.expm",
+ "kwargs": {
+ "uri": "file:" + str(Path(os.getcwd()).resolve() / "mlruns"),
+ "default_exp_name": "Experiment",
+ },
+ },
}
MODE_CONF = {
@@ -141,10 +154,11 @@ MODE_CONF = {
"redis_host": "127.0.0.1",
"redis_port": 6379,
"redis_task_db": 1,
- "kernels": 64,
+ "kernels": NUM_USABLE_CPU,
# cache
"expression_cache": "DiskExpressionCache",
"dataset_cache": "DiskDatasetCache",
+ "mount_path": None,
},
"client": {
# data provider config
@@ -162,7 +176,7 @@ MODE_CONF = {
"dataset_cache": "DiskDatasetCache",
"calendar_cache": None,
# client config
- "kernels": 16,
+ "kernels": NUM_USABLE_CPU,
"mount_path": None,
"auto_mount": False, # The nfs is already mounted on our server[auto_mount: False].
# The nfs should be auto-mounted by qlib on other
@@ -212,7 +226,9 @@ class QlibConfig(Config):
def get_uri_type(self):
is_win = re.match("^[a-zA-Z]:.*", self["provider_uri"]) is not None # such as 'C:\\data', 'D:'
- is_nfs_or_win = re.match("^[^/]+:.+", self["provider_uri"]) is not None # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
+ is_nfs_or_win = (
+ re.match("^[^/]+:.+", self["provider_uri"]) is not None
+ ) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
if is_nfs_or_win and not is_win:
return QlibConfig.NFS_URI
diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py
index ea7220133..7ee8dceb0 100644
--- a/qlib/contrib/backtest/backtest.py
+++ b/qlib/contrib/backtest/backtest.py
@@ -10,6 +10,7 @@ from ...data import D
from .account import Account
from ...config import C
from ...log import get_module_logger
+from ...data.dataset.utils import get_level_index
LOG = get_module_logger("backtest")
@@ -18,7 +19,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
"""Parameters
----------
pred : pandas.DataFrame
- predict should has index and one `score` column
+ predict should has index and one `score` column
+ Qlib want to support multi-singal strategy in the future. So pd.Series is not used.
strategy : Strategy()
strategy part for backtest
trade_exchange : Exchange()
@@ -43,6 +45,12 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
`benchmark` is str, will use the daily change as the 'bench'.
benchmark code, default is SH000905 CSI500
"""
+ # Convert format if the input format is not expected
+ if get_level_index(pred, level="datetime") == 1:
+ pred = pred.swaplevel().sort_index()
+ if isinstance(pred, pd.Series):
+ pred = pred.to_frame("score")
+
trade_account = Account(init_cash=account)
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
@@ -57,6 +65,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
get_date_by_shift(predict_dates[-1], shift=shift),
disk_cache=1,
)
+ if len(_temp_result) == 0:
+ raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], shift=shift))
@@ -71,7 +81,7 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
# 1. Load the score_series at pred_date
try:
- score = pred.loc(axis=0)[:, pred_date] # (stock_id, trade_date) multi_index, score in pdate
+ score = pred.loc(axis=0)[pred_date, :] # (trade_date, stock_id) multi_index, score in pdate
score_series = score.reset_index(level="datetime", drop=True)[
"score"
] # pd.Series(index:stock_id, data: score)
diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py
index b20d7012f..6c269d505 100644
--- a/qlib/contrib/backtest/position.py
+++ b/qlib/contrib/backtest/position.py
@@ -166,7 +166,7 @@ class Position:
def save_position(self, path, last_trade_date):
path = pathlib.Path(path)
p = copy.deepcopy(self.position)
- cash = pd.Series()
+ cash = pd.Series(dtype=np.float)
cash["init_cash"] = self.init_cash
cash["cash"] = p["cash"]
cash["today_account_value"] = p["today_account_value"]
diff --git a/qlib/contrib/estimator/__init__.py b/qlib/contrib/data/__init__.py
similarity index 100%
rename from qlib/contrib/estimator/__init__.py
rename to qlib/contrib/data/__init__.py
diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py
new file mode 100644
index 000000000..e97b00c24
--- /dev/null
+++ b/qlib/contrib/data/handler.py
@@ -0,0 +1,429 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from ...data.dataset.handler import DataHandlerLP
+from ...data.dataset.processor import Processor
+from ...utils import get_cls_kwargs
+from ...data.dataset import processor as processor_module
+from ...log import TimeInspector
+from inspect import getfullargspec
+import copy
+
+
+def check_transform_proc(proc_l, fit_start_time, fit_end_time):
+ new_l = []
+ for p in proc_l:
+ if not isinstance(p, Processor):
+ klass, pkwargs = get_cls_kwargs(p, processor_module)
+ args = getfullargspec(klass).args
+ if "fit_start_time" in args and "fit_end_time" in args:
+ assert (
+ fit_start_time is not None and fit_end_time is not None
+ ), "Make sure `fit_start_time` and `fit_end_time` are not None."
+ pkwargs.update(
+ {
+ "fit_start_time": fit_start_time,
+ "fit_end_time": fit_end_time,
+ }
+ )
+ new_l.append({"class": klass.__name__, "kwargs": pkwargs})
+ else:
+ new_l.append(p)
+ return new_l
+
+
+class ALPHA360_Denoise(DataHandlerLP):
+ def __init__(self, instruments="csi500", start_time=None, end_time=None, fit_start_time=None, fit_end_time=None):
+ data_loader = {
+ "class": "QlibDataLoader",
+ "kwargs": {
+ "config": {
+ "feature": self.get_feature_config(),
+ "label": self.get_label_config(),
+ },
+ },
+ }
+
+ learn_processors = [
+ {"class": "DropnaLabel", "kwargs": {"fields_group": "label"}},
+ {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
+ ]
+ infer_processors = [
+ {"class": "ProcessInf", "kwargs": {}},
+ {"class": "TanhProcess", "kwargs": {}},
+ {"class": "Fillna", "kwargs": {}},
+ ]
+
+ super().__init__(
+ instruments,
+ start_time,
+ end_time,
+ data_loader=data_loader,
+ learn_processors=learn_processors,
+ infer_processors=infer_processors,
+ )
+
+ def get_label_config(self):
+ return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
+
+ def get_feature_config(self):
+
+ fields = []
+ names = []
+
+ for i in range(59, 0, -1):
+ fields += ["Ref($close, %d)/$close" % (i)]
+ names += ["CLOSE%d" % (i)]
+ fields += ["$close/$close"]
+ names += ["CLOSE0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($open, %d)/$close" % (i)]
+ names += ["OPEN%d" % (i)]
+ fields += ["$open/$close"]
+ names += ["OPEN0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($high, %d)/$close" % (i)]
+ names += ["HIGH%d" % (i)]
+ fields += ["$high/$close"]
+ names += ["HIGH0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($low, %d)/$close" % (i)]
+ names += ["LOW%d" % (i)]
+ fields += ["$low/$close"]
+ names += ["LOW0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($vwap, %d)/$close" % (i)]
+ names += ["VWAP%d" % (i)]
+ fields += ["$vwap/$close"]
+ names += ["VWAP0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($volume, %d)/$volume" % (i)]
+ names += ["VOLUME%d" % (i)]
+ fields += ["$volume/$volume"]
+ names += ["VOLUME0"]
+
+ return fields, names
+
+
+_DEFAULT_LEARN_PROCESSORS = [
+ {"class": "DropnaLabel"},
+ {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
+]
+_DEFAULT_INFER_PROCESSORS = [
+ {"class": "ProcessInf", "kwargs": {}},
+ {"class": "ZScoreNorm", "kwargs": {}},
+ {"class": "Fillna", "kwargs": {}},
+]
+
+
+class ALPHA360(DataHandlerLP):
+ def __init__(
+ self,
+ instruments="csi500",
+ start_time=None,
+ end_time=None,
+ infer_processors=_DEFAULT_INFER_PROCESSORS,
+ learn_processors=_DEFAULT_LEARN_PROCESSORS,
+ fit_start_time=None,
+ fit_end_time=None,
+ **kwargs,
+ ):
+ infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
+ learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
+
+ data_loader = {
+ "class": "QlibDataLoader",
+ "kwargs": {
+ "config": {
+ "feature": self.get_feature_config(),
+ "label": kwargs.get("label", self.get_label_config()),
+ },
+ },
+ }
+
+ super().__init__(
+ instruments,
+ start_time,
+ end_time,
+ data_loader=data_loader,
+ learn_processors=learn_processors,
+ infer_processors=infer_processors,
+ )
+
+ def get_label_config(self):
+ return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
+
+ def get_feature_config(self):
+
+ fields = []
+ names = []
+
+ for i in range(59, 0, -1):
+ fields += ["Ref($close, %d)/$close" % (i)]
+ names += ["CLOSE%d" % (i)]
+ fields += ["$close/$close"]
+ names += ["CLOSE0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($open, %d)/$close" % (i)]
+ names += ["OPEN%d" % (i)]
+ fields += ["$open/$close"]
+ names += ["OPEN0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($high, %d)/$close" % (i)]
+ names += ["HIGH%d" % (i)]
+ fields += ["$high/$close"]
+ names += ["HIGH0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($low, %d)/$close" % (i)]
+ names += ["LOW%d" % (i)]
+ fields += ["$low/$close"]
+ names += ["LOW0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($vwap, %d)/$close" % (i)]
+ names += ["VWAP%d" % (i)]
+ fields += ["$vwap/$close"]
+ names += ["VWAP0"]
+ for i in range(59, 0, -1):
+ fields += ["Ref($volume, %d)/$volume" % (i)]
+ names += ["VOLUME%d" % (i)]
+ fields += ["$volume/$volume"]
+ names += ["VOLUME0"]
+
+ return fields, names
+
+
+class ALPHA360vwap(ALPHA360):
+ def get_label_config(self):
+ return (["Ref($vwap, -2)/Ref($vwap, -1) - 1"], ["LABEL0"])
+
+
+class Alpha158(DataHandlerLP):
+ def __init__(
+ self,
+ instruments="csi500",
+ start_time=None,
+ end_time=None,
+ infer_processors=[],
+ learn_processors=_DEFAULT_LEARN_PROCESSORS,
+ fit_start_time=None,
+ fit_end_time=None,
+ process_type=DataHandlerLP.PTYPE_A,
+ **kwargs,
+ ):
+ infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
+ learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
+
+ data_loader = {
+ "class": "QlibDataLoader",
+ "kwargs": {
+ "config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())},
+ },
+ }
+ super().__init__(
+ instruments,
+ start_time,
+ end_time,
+ data_loader=data_loader,
+ infer_processors=infer_processors,
+ learn_processors=learn_processors,
+ process_type=process_type,
+ )
+
+ def get_feature_config(self):
+ conf = {
+ "kbar": {},
+ "price": {
+ "windows": [0],
+ "feature": ["OPEN", "HIGH", "LOW", "VWAP"],
+ },
+ "rolling": {},
+ }
+ return self.parse_config_to_fields(conf)
+
+ def get_label_config(self):
+ return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
+
+ @staticmethod
+ def parse_config_to_fields(config):
+ """create factors from config
+
+ config = {
+ 'kbar': {}, # whether to use some hard-code kbar features
+ 'price': { # whether to use raw price features
+ 'windows': [0, 1, 2, 3, 4], # use price at n days ago
+ 'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
+ },
+ 'volume': { # whether to use raw volume features
+ 'windows': [0, 1, 2, 3, 4], # use volume at n days ago
+ },
+ 'rolling': { # whether to use rolling operator based features
+ 'windows': [5, 10, 20, 30, 60], # rolling windows size
+ 'include': ['ROC', 'MA', 'STD'], # rolling operator to use
+ #if include is None we will use default operators
+ 'exclude': ['RANK'], # rolling operator not to use
+ }
+ }
+ """
+ fields = []
+ names = []
+ if "kbar" in config:
+ fields += [
+ "($close-$open)/$open",
+ "($high-$low)/$open",
+ "($close-$open)/($high-$low+1e-12)",
+ "($high-Greater($open, $close))/$open",
+ "($high-Greater($open, $close))/($high-$low+1e-12)",
+ "(Less($open, $close)-$low)/$open",
+ "(Less($open, $close)-$low)/($high-$low+1e-12)",
+ "(2*$close-$high-$low)/$open",
+ "(2*$close-$high-$low)/($high-$low+1e-12)",
+ ]
+ names += [
+ "KMID",
+ "KLEN",
+ "KMID2",
+ "KUP",
+ "KUP2",
+ "KLOW",
+ "KLOW2",
+ "KSFT",
+ "KSFT2",
+ ]
+ if "price" in config:
+ windows = config["price"].get("windows", range(5))
+ feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
+ for field in feature:
+ field = field.lower()
+ fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
+ names += [field.upper() + str(d) for d in windows]
+ if "volume" in config:
+ windows = config["volume"].get("windows", range(5))
+ fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
+ names += ["VOLUME" + str(d) for d in windows]
+ if "rolling" in config:
+ windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
+ include = config["rolling"].get("include", None)
+ exclude = config["rolling"].get("exclude", [])
+ # `exclude` in dataset config unnecessary filed
+ # `include` in dataset config necessary field
+ use = lambda x: x not in exclude and (include is None or x in include)
+ if use("ROC"):
+ fields += ["Ref($close, %d)/$close" % d for d in windows]
+ names += ["ROC%d" % d for d in windows]
+ if use("MA"):
+ fields += ["Mean($close, %d)/$close" % d for d in windows]
+ names += ["MA%d" % d for d in windows]
+ if use("STD"):
+ fields += ["Std($close, %d)/$close" % d for d in windows]
+ names += ["STD%d" % d for d in windows]
+ if use("BETA"):
+ fields += ["Slope($close, %d)/$close" % d for d in windows]
+ names += ["BETA%d" % d for d in windows]
+ if use("RSQR"):
+ fields += ["Rsquare($close, %d)" % d for d in windows]
+ names += ["RSQR%d" % d for d in windows]
+ if use("RESI"):
+ fields += ["Resi($close, %d)/$close" % d for d in windows]
+ names += ["RESI%d" % d for d in windows]
+ if use("MAX"):
+ fields += ["Max($high, %d)/$close" % d for d in windows]
+ names += ["MAX%d" % d for d in windows]
+ if use("LOW"):
+ fields += ["Min($low, %d)/$close" % d for d in windows]
+ names += ["MIN%d" % d for d in windows]
+ if use("QTLU"):
+ fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
+ names += ["QTLU%d" % d for d in windows]
+ if use("QTLD"):
+ fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
+ names += ["QTLD%d" % d for d in windows]
+ if use("RANK"):
+ fields += ["Rank($close, %d)" % d for d in windows]
+ names += ["RANK%d" % d for d in windows]
+ if use("RSV"):
+ fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
+ names += ["RSV%d" % d for d in windows]
+ if use("IMAX"):
+ fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
+ names += ["IMAX%d" % d for d in windows]
+ if use("IMIN"):
+ fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
+ names += ["IMIN%d" % d for d in windows]
+ if use("IMXD"):
+ fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
+ names += ["IMXD%d" % d for d in windows]
+ if use("CORR"):
+ fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
+ names += ["CORR%d" % d for d in windows]
+ if use("CORD"):
+ fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
+ names += ["CORD%d" % d for d in windows]
+ if use("CNTP"):
+ fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
+ names += ["CNTP%d" % d for d in windows]
+ if use("CNTN"):
+ fields += ["Mean($close[Ref($close, 1), %d)-Mean($close][= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
+ if self.fillna_feature:
+ x.fillna(0, inplace=True)
+ return x
+
+ TimeInspector.set_time_mark()
+
+ # Copy the focus part and change it to single level
+ selected_cols = get_group_columns(df, self.fields_group)
+ df_focus = df[selected_cols].copy()
+ if len(df_focus.columns.levels) > 1:
+ df_focus = df_focus.droplevel(level=0)
+
+ # Label
+ cols = df_focus.columns[df_focus.columns.str.contains("^LABEL")]
+ df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_label_norm)
+
+ # Features
+ cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
+ df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
+
+ cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
+ df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
+
+ _cols = [
+ "KMID",
+ "KSFT",
+ "OPEN",
+ "HIGH",
+ "LOW",
+ "CLOSE",
+ "VWAP",
+ "ROC",
+ "MA",
+ "BETA",
+ "RESI",
+ "QTLU",
+ "QTLD",
+ "RSV",
+ "SUMP",
+ "SUMN",
+ "SUMD",
+ "VSUMP",
+ "VSUMN",
+ "VSUMD",
+ ]
+ pat = "|".join(["^" + x for x in _cols])
+ cols = df_focus.columns[df_focus.columns.str.contains(pat) & (~df_focus.columns.isin(["HIGH0", "LOW0"]))]
+ df_focus[cols] = df_focus[cols].groupby(level="datetime").apply(_feature_norm)
+
+ cols = df_focus.columns[df_focus.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
+ df_focus[cols] = df_focus[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
+
+ cols = df_focus.columns[df_focus.columns.str.contains("^RSQR")]
+ df_focus[cols] = df_focus[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
+
+ cols = df_focus.columns[df_focus.columns.str.contains("^MAX|^HIGH0")]
+ df_focus[cols] = df_focus[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
+
+ cols = df_focus.columns[df_focus.columns.str.contains("^MIN|^LOW0")]
+ df_focus[cols] = df_focus[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
+
+ cols = df_focus.columns[df_focus.columns.str.contains("^CORR|^CORD")]
+ df_focus[cols] = df_focus[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
+
+ cols = df_focus.columns[df_focus.columns.str.contains("^WVMA")]
+ df_focus[cols] = df_focus[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
+
+ df[selected_cols] = df_focus.values
+
+ TimeInspector.log_cost_time("Finished preprocessing data.")
+
+ return df
diff --git a/qlib/contrib/estimator/config.py b/qlib/contrib/estimator/config.py
deleted file mode 100644
index 5a4a31613..000000000
--- a/qlib/contrib/estimator/config.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import yaml
-import copy
-import os
-import json
-import tempfile
-from pathlib import Path
-from ...config import REG_CN
-
-
-class EstimatorConfigManager(object):
- def __init__(self, config_path):
-
- if not config_path:
- raise ValueError("Config path is invalid.")
- self.config_path = config_path
-
- with open(config_path) as fp:
- config = yaml.load(fp, Loader=yaml.FullLoader)
- self.config = copy.deepcopy(config)
-
- self.ex_config = ExperimentConfig(config.get("experiment", dict()), self)
- self.data_config = DataConfig(config.get("data", dict()), self)
- self.model_config = ModelConfig(config.get("model", dict()), self)
- self.trainer_config = TrainerConfig(config.get("trainer", dict()), self)
- self.strategy_config = StrategyConfig(config.get("strategy", dict()), self)
- self.backtest_config = BacktestConfig(config.get("backtest", dict()), self)
- self.qlib_data_config = QlibDataConfig(config.get("qlib_data", dict()), self)
-
- # If the start_date and end_date are not given in data_config, they will be referred from the trainer_config.
- handler_start_date = self.data_config.handler_parameters.get("start_date", None)
- handler_end_date = self.data_config.handler_parameters.get("end_date", None)
- if handler_start_date is None:
- self.data_config.handler_parameters["start_date"] = self.trainer_config.parameters["train_start_date"]
- if handler_end_date is None:
- self.data_config.handler_parameters["end_date"] = self.trainer_config.parameters["test_end_date"]
-
-
-class ExperimentConfig(object):
- TRAIN_MODE = "train"
- TEST_MODE = "test"
-
- OBSERVER_FILE_STORAGE = "file_storage"
- OBSERVER_MONGO = "mongo"
-
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for experiment
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.name = config.get("name", "test_experiment")
- # The dir of the result of all the experiments
- self.global_dir = config.get("dir", os.path.dirname(CONFIG_MANAGER.config_path))
- # The dir of the result of current experiment
- self.ex_dir = os.path.join(self.global_dir, self.name)
- if not os.path.exists(self.ex_dir):
- os.makedirs(self.ex_dir)
- self.tmp_run_dir = tempfile.mkdtemp(dir=self.ex_dir)
- self.mode = config.get("mode", ExperimentConfig.TRAIN_MODE)
- self.sacred_dir = os.path.join(self.ex_dir, "sacred")
- self.observer_type = config.get("observer_type", ExperimentConfig.OBSERVER_FILE_STORAGE)
- self.mongo_url = config.get("mongo_url", None)
- self.db_name = config.get("db_name", None)
- self.finetune = config.get("finetune", False)
-
- # The path of the experiment id of the experiment
- self.exp_info_path = config.get("exp_info_path", os.path.join(self.ex_dir, "exp_info.json"))
- exp_info_dir = Path(self.exp_info_path).parent
- exp_info_dir.mkdir(parents=True, exist_ok=True)
-
- # Test mode config
- loader_args = config.get("loader", dict())
- if self.mode == ExperimentConfig.TEST_MODE or self.finetune:
- loader_exp_info_path = loader_args.get("exp_info_path", None)
- self.loader_model_index = loader_args.get("model_index", None)
- if (loader_exp_info_path is not None) and (os.path.exists(loader_exp_info_path)):
- with open(loader_exp_info_path) as fp:
- loader_dict = json.load(fp)
- for k, v in loader_dict.items():
- setattr(self, "loader_{}".format(k), v)
- # Check loader experiment id
- assert hasattr(self, "loader_id"), "If mode is test or finetune is True, loader must contain id."
- else:
- self.loader_id = loader_args.get("id", None)
- if self.loader_id is None:
- raise ValueError("If mode is test or finetune is True, loader must contain id.")
-
- self.loader_observer_type = loader_args.get("observer_type", self.observer_type)
- self.loader_name = loader_args.get("name", self.name)
- self.loader_dir = loader_args.get("dir", self.global_dir)
-
- self.loader_mongo_url = loader_args.get("mongo_url", self.mongo_url)
- self.loader_db_name = loader_args.get("db_name", self.db_name)
-
-
-class DataConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for data
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.handler_module_path = config.get("module_path", "qlib.contrib.estimator.handler")
- self.handler_class = config.get("class", "ALPHA360")
- self.handler_parameters = config.get("args", dict())
- self.handler_filter = config.get("filter", dict())
- # Update provider uri.
-
-
-class ModelConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for model
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.model_class = config.get("class", "Model")
- self.model_module_path = config.get("module_path", "qlib.contrib.model")
- self.save_dir = os.path.join(CONFIG_MANAGER.ex_config.tmp_run_dir, "model")
- self.save_path = config.get("save_path", os.path.join(self.save_dir, "model.bin"))
- self.parameters = config.get("args", dict())
- # Make dir if need.
- if not os.path.exists(self.save_dir):
- os.makedirs(self.save_dir)
-
-
-class TrainerConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for trainer
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.trainer_class = config.get("class", "StaticTrainer")
- self.trainer_module_path = config.get("module_path", "qlib.contrib.estimator.trainer")
- self.parameters = config.get("args", dict())
-
-
-class StrategyConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for strategy
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.strategy_class = config.get("class", "TopkDropoutStrategy")
- self.strategy_module_path = config.get("module_path", "qlib.contrib.strategy.strategy")
- self.parameters = config.get("args", dict())
-
-
-class BacktestConfig(object):
- def __init__(self, config, CONFIG_MANAGE):
- """__init__
-
- :param config: The config dict for strategy
- :param CONFIG_MANAGE: The estimator config manager
- """
- self.normal_backtest_parameters = config.get("normal_backtest_args", dict())
- self.long_short_backtest_parameters = config.get("long_short_backtest_args", dict())
-
-
-class QlibDataConfig(object):
- def __init__(self, config, CONFIG_MANAGE):
- """__init__
-
- :param config: The config dict for qlib_client
- :param CONFIG_MANAGE: The estimator config manager
- """
- self.provider_uri = config.pop("provider_uri", "~/.qlib/qlib_data/cn_data")
- self.auto_mount = config.pop("auto_mount", False)
- self.mount_path = config.pop("mount_path", "~/.qlib/qlib_data/cn_data")
- self.region = config.pop("region", REG_CN)
- self.args = config
diff --git a/qlib/contrib/estimator/estimator.py b/qlib/contrib/estimator/estimator.py
deleted file mode 100644
index 56495e5eb..000000000
--- a/qlib/contrib/estimator/estimator.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# coding=utf-8
-
-import pandas as pd
-
-import os
-import copy
-import json
-import yaml
-import pickle
-
-import qlib
-from ..evaluate import risk_analysis
-from ..evaluate import backtest as normal_backtest
-from ..evaluate import long_short_backtest
-from .config import ExperimentConfig
-from .fetcher import create_fetcher_with_config
-
-from ...log import get_module_logger, TimeInspector
-from ...utils import get_module_by_module_path, compare_dict_value
-
-
-class Estimator(object):
- def __init__(self, config_manager, sacred_ex):
-
- # Set logger.
- self.logger = get_module_logger("Estimator")
-
- # 1. Set config manager.
- self.config_manager = config_manager
-
- # 2. Set configs.
- self.ex_config = config_manager.ex_config
- self.data_config = config_manager.data_config
- self.model_config = config_manager.model_config
- self.trainer_config = config_manager.trainer_config
- self.strategy_config = config_manager.strategy_config
- self.backtest_config = config_manager.backtest_config
-
- # If experiment.mode is test or experiment.finetune is True, load the experimental results in the loader
- if self.ex_config.mode == self.ex_config.TEST_MODE or self.ex_config.finetune:
- self.compare_config_with_config_manger(self.config_manager)
-
- # 3. Set sacred_experiment.
- self.ex = sacred_ex
-
- # 4. Init data handler.
- self.data_handler = None
- self._init_data_handler()
-
- # 5. Init trainer.
- self.trainer = None
- self._init_trainer()
-
- # 6. Init strategy.
- self.strategy = None
- self._init_strategy()
-
- def _init_data_handler(self):
- handler_module = get_module_by_module_path(self.data_config.handler_module_path)
-
- # Set market
- market = self.data_config.handler_filter.get("market", None)
- if market is None:
- if "market" in self.data_config.handler_parameters:
- self.logger.warning(
- "Warning: The market in data.args section is deprecated. "
- "It only works when market is not set in data.filter section. "
- "It will be overridden by market in the data.filter section."
- )
- market = self.data_config.handler_parameters["market"]
- else:
- market = "csi500"
-
- self.data_config.handler_parameters["market"] = market
-
- data_filter_list = []
- handler_filters = self.data_config.handler_filter.get("filter_pipeline", list())
- for h_filter in handler_filters:
- filter_module_path = h_filter.get("module_path", "qlib.data.filter")
- filter_class_name = h_filter.get("class", "")
- filter_parameters = h_filter.get("args", {})
- filter_module = get_module_by_module_path(filter_module_path)
- filter_class = getattr(filter_module, filter_class_name)
- data_filter = filter_class(**filter_parameters)
- data_filter_list.append(data_filter)
-
- self.data_config.handler_parameters["data_filter_list"] = data_filter_list
- handler_class = getattr(handler_module, self.data_config.handler_class)
- self.data_handler = handler_class(**self.data_config.handler_parameters)
-
- def _init_trainer(self):
-
- model_module = get_module_by_module_path(self.model_config.model_module_path)
- trainer_module = get_module_by_module_path(self.trainer_config.trainer_module_path)
- model_class = getattr(model_module, self.model_config.model_class)
- trainer_class = getattr(trainer_module, self.trainer_config.trainer_class)
-
- self.trainer = trainer_class(
- model_class,
- self.model_config.save_path,
- self.model_config.parameters,
- self.data_handler,
- self.ex,
- **self.trainer_config.parameters
- )
-
- def _init_strategy(self):
-
- module = get_module_by_module_path(self.strategy_config.strategy_module_path)
- strategy_class = getattr(module, self.strategy_config.strategy_class)
- self.strategy = strategy_class(**self.strategy_config.parameters)
-
- def run(self):
- if self.ex_config.mode == ExperimentConfig.TRAIN_MODE:
- self.trainer.train()
- elif self.ex_config.mode == ExperimentConfig.TEST_MODE:
- self.trainer.load()
- else:
- raise ValueError("unexpected mode: %s" % self.ex_config.mode)
- analysis = self.backtest()
- print(analysis)
- self.logger.info(
- "experiment id: {}, experiment name: {}".format(self.ex.experiment.current_run._id, self.ex_config.name)
- )
-
- # Remove temp dir
- # shutil.rmtree(self.ex_config.tmp_run_dir)
-
- def backtest(self):
- TimeInspector.set_time_mark()
- # 1. Get pred and prediction score of model(s).
- pred = self.trainer.get_test_score()
- try:
- performance = self.trainer.get_test_performance()
- except NotImplementedError:
- performance = None
- # 2. Normal Backtest.
- report_normal, positions_normal = self._normal_backtest(pred)
- # 3. Long-Short Backtest.
- # Deprecated
- # long_short_reports = self._long_short_backtest(pred)
- # 4. Analyze
- analysis_df = self._analyze(report_normal)
- # 5. Save.
- self._save_backtest_result(
- pred,
- analysis_df,
- positions_normal,
- report_normal,
- # long_short_reports,
- performance,
- )
- return analysis_df
-
- def _normal_backtest(self, pred):
- TimeInspector.set_time_mark()
- if "account" not in self.backtest_config.normal_backtest_parameters:
- if "account" in self.strategy_config.parameters:
- self.logger.warning(
- "Warning: The account in strategy section is deprecated. "
- "It only works when account is not set in backtest section. "
- "It will be overridden by account in the backtest section."
- )
- self.backtest_config.normal_backtest_parameters["account"] = self.strategy_config.parameters["account"]
- report_normal, positions_normal = normal_backtest(
- pred, strategy=self.strategy, **self.backtest_config.normal_backtest_parameters
- )
- TimeInspector.log_cost_time("Finished normal backtest.")
- return report_normal, positions_normal
-
- def _long_short_backtest(self, pred):
- TimeInspector.set_time_mark()
- long_short_reports = long_short_backtest(pred, **self.backtest_config.long_short_backtest_parameters)
- TimeInspector.log_cost_time("Finished long-short backtest.")
- return long_short_reports
-
- @staticmethod
- def _analyze(report_normal):
- TimeInspector.set_time_mark()
-
- analysis = dict()
- # analysis["pred_long"] = risk_analysis(long_short_reports["long"])
- # analysis["pred_short"] = risk_analysis(long_short_reports["short"])
- # analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"])
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- TimeInspector.log_cost_time(
- "Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean())
- )
- return analysis_df
-
- def _save_backtest_result(self, pred, analysis, positions, report_normal, performance):
- # 1. Result dir.
- result_dir = os.path.join(self.config_manager.ex_config.tmp_run_dir, "result")
- if not os.path.exists(result_dir):
- os.makedirs(result_dir)
-
- self.ex.add_info(
- "task_config",
- json.loads(json.dumps(self.config_manager.config, default=str)),
- )
-
- # 2. Pred.
- TimeInspector.set_time_mark()
- pred_pkl_path = os.path.join(result_dir, "pred.pkl")
- pred.to_pickle(pred_pkl_path)
- self.ex.add_artifact(pred_pkl_path)
- TimeInspector.log_cost_time("Finished saving pred.pkl to: {}".format(pred_pkl_path))
-
- # 3. Ana.
- TimeInspector.set_time_mark()
- analysis_pkl_path = os.path.join(result_dir, "analysis.pkl")
- analysis.to_pickle(analysis_pkl_path)
- self.ex.add_artifact(analysis_pkl_path)
- TimeInspector.log_cost_time("Finished saving analysis.pkl to: {}".format(analysis_pkl_path))
-
- # 4. Pos.
- TimeInspector.set_time_mark()
- positions_pkl_path = os.path.join(result_dir, "positions.pkl")
- with open(positions_pkl_path, "wb") as fp:
- pickle.dump(positions, fp)
- self.ex.add_artifact(positions_pkl_path)
- TimeInspector.log_cost_time("Finished saving positions.pkl to: {}".format(positions_pkl_path))
-
- # 5. Report normal.
- TimeInspector.set_time_mark()
- report_normal_pkl_path = os.path.join(result_dir, "report_normal.pkl")
- report_normal.to_pickle(report_normal_pkl_path)
- self.ex.add_artifact(report_normal_pkl_path)
- TimeInspector.log_cost_time("Finished saving report_normal.pkl to: {}".format(report_normal_pkl_path))
-
- # 6. Report long short.
- # Deprecated
- # for k, name in zip(
- # ["long", "short", "long_short"],
- # ["report_long.pkl", "report_short.pkl", "report_long_short.pkl"],
- # ):
- # TimeInspector.set_time_mark()
- # pkl_path = os.path.join(result_dir, name)
- # long_short_reports[k].to_pickle(pkl_path)
- # self.ex.add_artifact(pkl_path)
- # TimeInspector.log_cost_time("Finished saving {} to: {}".format(name, pkl_path))
-
- # 7. Origin test label.
- TimeInspector.set_time_mark()
- label_pkl_path = os.path.join(result_dir, "label.pkl")
- self.data_handler.get_origin_test_label_with_date(
- self.trainer_config.parameters["test_start_date"],
- self.trainer_config.parameters["test_end_date"],
- ).to_pickle(label_pkl_path)
- self.ex.add_artifact(label_pkl_path)
- TimeInspector.log_cost_time("Finished saving label.pkl to: {}".format(label_pkl_path))
-
- # 8. Experiment info, save the model(s) performance here.
- TimeInspector.set_time_mark()
- cur_ex_id = self.ex.experiment.current_run._id
- exp_info = {
- "id": cur_ex_id,
- "name": self.ex_config.name,
- "performance": performance,
- "observer_type": self.ex_config.observer_type,
- }
-
- if self.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
- exp_info.update(
- {
- "mongo_url": self.ex_config.mongo_url,
- "db_name": self.ex_config.db_name,
- }
- )
- else:
- exp_info.update({"dir": self.ex_config.global_dir})
-
- with open(self.ex_config.exp_info_path, "w") as fp:
- json.dump(exp_info, fp, indent=4, sort_keys=True)
- self.ex.add_artifact(self.ex_config.exp_info_path)
- TimeInspector.log_cost_time("Finished saving ex_info to: {}".format(self.ex_config.exp_info_path))
-
- @staticmethod
- def compare_config_with_config_manger(config_manager):
- """Compare loader model args and current config with ConfigManage
-
- :param config_manager: ConfigManager
- :return:
- """
- fetcher = create_fetcher_with_config(config_manager, load_form_loader=True)
- loader_mode_config = fetcher.get_experiment(
- exp_name=config_manager.ex_config.loader_name,
- exp_id=config_manager.ex_config.loader_id,
- fields=["task_config"],
- )["task_config"]
- with open(config_manager.config_path) as fp:
- current_config = yaml.load(fp.read())
- current_config = json.loads(json.dumps(current_config, default=str))
-
- logger = get_module_logger("Estimator")
-
- loader_mode_config = copy.deepcopy(loader_mode_config)
- current_config = copy.deepcopy(current_config)
-
- # Require test_mode_config.test_start_date <= current_config.test_start_date
- loader_trainer_args = loader_mode_config.get("trainer", {}).get("args", {})
- cur_trainer_args = current_config.get("trainer", {}).get("args", {})
- loader_start_date = loader_trainer_args.pop("test_start_date")
- cur_test_start_date = cur_trainer_args.pop("test_start_date")
- assert (
- loader_start_date <= cur_test_start_date
- ), "Require: loader_mode_config.test_start_date <= current_config.test_start_date"
-
- # TODO: For the user's own extended `Trainer`, the support is not very good
- if "RollingTrainer" == current_config.get("trainer", {}).get("class", None):
- loader_period = loader_trainer_args.pop("rolling_period")
- cur_period = cur_trainer_args.pop("rolling_period")
- assert (
- loader_period == cur_period
- ), "Require: loader_mode_config.rolling_period == current_config.rolling_period"
-
- compare_section = ["trainer", "model", "data"]
- for section in compare_section:
- changes = compare_dict_value(loader_mode_config.get(section, {}), current_config.get(section, {}))
- if changes:
- logger.warning("Warning: Loader mode config and current config, `{}` are different:\n".format(section))
diff --git a/qlib/contrib/estimator/fetcher.py b/qlib/contrib/estimator/fetcher.py
deleted file mode 100644
index 16ef1dc60..000000000
--- a/qlib/contrib/estimator/fetcher.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# coding=utf-8
-
-import copy
-import json
-import yaml
-import pickle
-import gridfs
-import pymongo
-from pathlib import Path
-from abc import abstractmethod
-
-from .config import EstimatorConfigManager, ExperimentConfig
-
-
-class Fetcher(object):
- """Sacred Experiments Fetcher"""
-
- @abstractmethod
- def _get_experiment(self, exp_name, exp_id):
- """Get experiment basic info with experiment and experiment id
-
- :param exp_name: experiment name
- :param exp_id: experiment id
- :return: dict
- Must contain keys: _id, experiment, info, stop_time.
- Here is an example below for FileFetcher.
- exp = {
- '_id': exp_id, # experiment id
- 'path': path, # experiment result path
- 'experiment': {'name': exp_name}, # experiment
- 'info': info, # experiment config info
- 'stop_time': run.get('stop_time', None) # The time the experiment ended
- }
-
- """
- pass
-
- @abstractmethod
- def _list_experiments(self, exp_name=None):
- """Get experiment basic info list with experiment name
-
- :param exp_name: experiment name
- :return: list
-
- """
- pass
-
- @abstractmethod
- def _iter_artifacts(self, experiment):
- """Get information about the data in the experiment results
-
- :param experiment: `self._get_experiment` method result
- :return: iterable
- Each element contains two elements.
- first element : data name
- second element : data uri
- """
- pass
-
- @abstractmethod
- def _load_data(self, uri):
- """Load data with uri
-
- :param uri: data uri
- :return: bytes
- """
- pass
-
- @staticmethod
- def model_dict_to_buffer_list(model_dict):
- """
-
- :param model_dict:
- :return:
- """
- model_list = []
- is_static_model = False
- if len(model_dict) == 1 and list(model_dict.keys())[0] == "model.bin":
- is_static_model = True
- model_list.append(list(model_dict.values())[0])
- else:
- sep = "model.bin_"
- model_ids = list(map(lambda x: int(x.split(sep)[1]), model_dict.keys()))
- min_id, max_id = min(model_ids), max(model_ids)
- for i in range(min_id, max_id + 1):
- model_key = sep + str(i)
- model = model_dict.get(model_key, None)
- if model is None:
- print(
- "WARNING: In Fetcher, {} is missing when the get model is in the get_experiment function.".format(
- model_key
- )
- )
- break
- else:
- model_list.append(model)
-
- if is_static_model:
- return model_list[0]
-
- return model_list
-
- def get_experiments(self, exp_name=None):
- """Get experiments with name.
-
- :param exp_name: str
- If `exp_name` is set to None, then all experiments will return.
- :return: dict
- Experiments info dict(Including experiment id and task_config to run the
- experiment). Here is an example below.
- {
- 'a_experiment': [
- {
- 'id': '1',
- 'task_config': {...}
- },
- ...
- ]
- ...
- }
- """
- res = dict()
- for ex in self._list_experiments(exp_name):
- name = ex["experiment"]["name"]
- tmp = {
- "id": ex["_id"],
- "task_config": ex["info"].get("task_config", {}),
- "ex_run_stop_time": ex.get("stop_time", None),
- }
- res.setdefault(name, []).append(tmp)
- return res
-
- def get_experiment(self, exp_name, exp_id, fields=None):
- """
-
- :param exp_name:
- :param exp_id:
- :param fields: list
- Experiment result fields, if fields is None, will get all fields.
- Currently supported fields:
- ['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
- :return: dict
- """
- fields = copy.copy(fields)
- ex = self._get_experiment(exp_name, exp_id)
- results = dict()
- model_dict = dict()
- for name, uri in self._iter_artifacts(ex):
- # When saving, use `sacred.experiment.add_artifact(filename)` , so `name` is os.path.basename(filename)
- prefix = name.split(".")[0]
- if fields and prefix not in fields:
- continue
- data = self._load_data(uri)
- if prefix == "model":
- model_dict[name] = data
- else:
- results[prefix] = pickle.loads(data)
- # Sort model
- if model_dict:
- results["model"] = self.model_dict_to_buffer_list(model_dict)
-
- # Info
- results["task_config"] = ex["info"].get("task_config", {})
- return results
-
- def estimator_config_to_dict(self, exp_name, exp_id):
- """Save configuration to file
-
- :param exp_name:
- :param exp_id:
- :return: config dict
- """
-
- return self.get_experiment(exp_name, exp_id, fields=["task_config"])["task_config"]
-
-
-class FileFetcher(Fetcher):
- """File Fetcher"""
-
- def __init__(self, experiments_dir):
- self.experiments_dir = Path(experiments_dir)
-
- def _get_experiment(self, exp_name, exp_id):
- path = self.experiments_dir / exp_name / "sacred" / str(exp_id)
- info_path = path / "info.json"
- run_path = path / "run.json"
-
- if info_path.exists():
- with info_path.open("r") as f:
- info = json.load(f)
- else:
- info = {}
-
- if run_path.exists():
- with run_path.open("r") as f:
- run = json.load(f)
- else:
- run = {}
-
- exp = {
- "_id": exp_id,
- "path": path,
- "experiment": {"name": exp_name},
- "info": info,
- "stop_time": run.get("stop_time", None),
- }
- return exp
-
- def _list_experiments(self, exp_name=None):
- runs = []
- for path in self.experiments_dir.glob("{}/sacred/[!_]*".format(exp_name or "*")):
- exp_name, exp_id = path.parents[1].name, path.name
- runs.append(self._get_experiment(exp_name, exp_id))
- return runs
-
- def _iter_artifacts(self, experiment):
- if experiment is None:
- return []
-
- for fname in experiment["path"].iterdir():
- if fname.suffix == ".pkl" or ".bin" in fname.suffix:
- name, uri = fname.name, str(fname)
- yield name, uri
-
- def _load_data(self, uri):
- with open(uri, "rb") as f:
- data = f.read()
- return data
-
-
-class MongoFetcher(Fetcher):
- """MongoDB Fetcher"""
-
- def __init__(self, mongo_url, db_name):
- self.mongo_url = mongo_url
- self.db_name = db_name
- self.client = None
- self.db = None
- self.runs = None
- self.fs = None
- self._setup_mongo_client()
-
- def _setup_mongo_client(self):
- self.client = pymongo.MongoClient(self.mongo_url)
- self.db = self.client[self.db_name]
- self.runs = self.db.runs
- self.fs = gridfs.GridFS(self.db)
-
- def _get_experiment(self, exp_name, exp_id):
- return self.runs.find_one({"_id": exp_id})
-
- def _list_experiments(self, exp_name=None):
- if exp_name is None:
- return self.runs.find()
- return self.runs.find({"experiment.name": exp_name})
-
- def _iter_artifacts(self, experiment):
- if experiment is None:
- return []
- for artifact in experiment.get("artifacts", []):
- name, uri = artifact["name"], artifact["file_id"]
- yield name, uri
-
- def _load_data(self, uri):
- data = self.fs.get(uri).read()
- return data
-
-
-def create_fetcher_with_config(config_manager: EstimatorConfigManager, load_form_loader: bool = False):
- """Create fetcher with loader config
-
- :param config_manager:
- :param load_form_loader
- :return:
- """
- flag = ""
- if load_form_loader:
- flag = "loader_"
- if config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_FILE_STORAGE:
- return FileFetcher(eval("config_manager.ex_config.{}_dir".format("loader" if load_form_loader else "global")))
- elif config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
- return MongoFetcher(
- mongo_url=eval("config_manager.ex_config.{}mongo_url".format(flag)),
- db_name=eval("config_manager.ex_config.{}db_name".format(flag)),
- )
- else:
- return NotImplementedError("Unkown Backend")
diff --git a/qlib/contrib/estimator/handler.py b/qlib/contrib/estimator/handler.py
deleted file mode 100644
index 3c30b01d8..000000000
--- a/qlib/contrib/estimator/handler.py
+++ /dev/null
@@ -1,585 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# coding=utf-8
-import abc
-import bisect
-import logging
-
-import pandas as pd
-import numpy as np
-
-from ...log import get_module_logger, TimeInspector
-from ...data import D
-from ...utils import parse_config, transform_end_date
-
-from . import processor as processor_module
-
-
-class BaseDataHandler(abc.ABC):
- def __init__(self, processors=[], **kwargs):
- """
- :param start_date:
- :param end_date:
- :param kwargs:
- """
- # Set logger
- self.logger = get_module_logger("DataHandler")
-
- # init data using kwargs
- self._init_kwargs(**kwargs)
-
- # Setup data.
- self.raw_df, self.feature_names, self.label_names = self._init_raw_df()
-
- # Setup preprocessor
- self.processors = []
- for klass in processors:
- if isinstance(klass, str):
- try:
- klass = getattr(processor_module, klass)
- except:
- raise ValueError("unknown Processor %s" % klass)
- self.processors.append(klass(self.feature_names, self.label_names, **kwargs))
-
- def _init_kwargs(self, **kwargs):
- """
- init the kwargs of DataHandler
- """
- pass
-
- def _init_raw_df(self):
- """
- init raw_df, feature_names, label_names of DataHandler
- if the index of df_feature and df_label are not same, user need to overload this method to merge (e.g. inner, left, right merge).
-
- """
- df_features = self.setup_feature()
- feature_names = df_features.columns
-
- df_labels = self.setup_label()
- label_names = df_labels.columns
-
- raw_df = df_features.merge(df_labels, left_index=True, right_index=True, how="left")
-
- return raw_df, feature_names, label_names
-
- def reset_label(self, df_labels):
- for col in self.label_names:
- del self.raw_df[col]
- self.label_names = df_labels.columns
- self.raw_df = self.raw_df.merge(df_labels, left_index=True, right_index=True, how="left")
-
- def split_rolling_periods(
- self,
- train_start_date,
- train_end_date,
- validate_start_date,
- validate_end_date,
- test_start_date,
- test_end_date,
- rolling_period,
- calendar_freq="day",
- ):
- """
- Calculating the Rolling split periods, the period rolling on market calendar.
- :param train_start_date:
- :param train_end_date:
- :param validate_start_date:
- :param validate_end_date:
- :param test_start_date:
- :param test_end_date:
- :param rolling_period: The market period of rolling
- :param calendar_freq: The frequence of the market calendar
- :yield: Rolling split periods
- """
-
- def get_start_index(calendar, start_date):
- start_index = bisect.bisect_left(calendar, start_date)
- return start_index
-
- def get_end_index(calendar, end_date):
- end_index = bisect.bisect_right(calendar, end_date)
- return end_index - 1
-
- calendar = self.raw_df.index.get_level_values("datetime").unique()
-
- train_start_index = get_start_index(calendar, pd.Timestamp(train_start_date))
- train_end_index = get_end_index(calendar, pd.Timestamp(train_end_date))
- valid_start_index = get_start_index(calendar, pd.Timestamp(validate_start_date))
- valid_end_index = get_end_index(calendar, pd.Timestamp(validate_end_date))
- test_start_index = get_start_index(calendar, pd.Timestamp(test_start_date))
- test_end_index = test_start_index + rolling_period - 1
-
- need_stop_split = False
-
- bound_test_end_index = get_end_index(calendar, pd.Timestamp(test_end_date))
-
- while not need_stop_split:
-
- if test_end_index > bound_test_end_index:
- test_end_index = bound_test_end_index
- need_stop_split = True
-
- yield (
- calendar[train_start_index],
- calendar[train_end_index],
- calendar[valid_start_index],
- calendar[valid_end_index],
- calendar[test_start_index],
- calendar[test_end_index],
- )
-
- train_start_index += rolling_period
- train_end_index += rolling_period
- valid_start_index += rolling_period
- valid_end_index += rolling_period
- test_start_index += rolling_period
- test_end_index += rolling_period
-
- def get_rolling_data(
- self,
- train_start_date,
- train_end_date,
- validate_start_date,
- validate_end_date,
- test_start_date,
- test_end_date,
- rolling_period,
- calendar_freq="day",
- ):
- # Set generator.
- for period in self.split_rolling_periods(
- train_start_date,
- train_end_date,
- validate_start_date,
- validate_end_date,
- test_start_date,
- test_end_date,
- rolling_period,
- calendar_freq,
- ):
- (
- x_train,
- y_train,
- x_validate,
- y_validate,
- x_test,
- y_test,
- ) = self.get_split_data(*period)
- yield x_train, y_train, x_validate, y_validate, x_test, y_test
-
- def get_split_data(
- self,
- train_start_date,
- train_end_date,
- validate_start_date,
- validate_end_date,
- test_start_date,
- test_end_date,
- ):
- """
- all return types are DataFrame
- """
- ## TODO: loc can be slow, expecially when we put it at the second level index.
- if self.raw_df.index.names[0] == "instrument":
- df_train = self.raw_df.loc(axis=0)[:, train_start_date:train_end_date]
- df_validate = self.raw_df.loc(axis=0)[:, validate_start_date:validate_end_date]
- df_test = self.raw_df.loc(axis=0)[:, test_start_date:test_end_date]
- else:
- df_train = self.raw_df.loc[train_start_date:train_end_date]
- df_validate = self.raw_df.loc[validate_start_date:validate_end_date]
- df_test = self.raw_df.loc[test_start_date:test_end_date]
-
- TimeInspector.set_time_mark()
- df_train, df_validate, df_test = self.setup_process_data(df_train, df_validate, df_test)
- TimeInspector.log_cost_time("Finished setup processed data.")
-
- x_train = df_train[self.feature_names]
- y_train = df_train[self.label_names]
-
- x_validate = df_validate[self.feature_names]
- y_validate = df_validate[self.label_names]
-
- x_test = df_test[self.feature_names]
- y_test = df_test[self.label_names]
-
- return x_train, y_train, x_validate, y_validate, x_test, y_test
-
- def setup_process_data(self, df_train, df_valid, df_test):
- """
- process the train, valid and test data
- :return: the processed train, valid and test data.
- """
- for processor in self.processors:
- df_train, df_valid, df_test = processor(df_train, df_valid, df_test)
- return df_train, df_valid, df_test
-
- def get_origin_test_label_with_date(self, test_start_date, test_end_date, freq="day"):
- """Get origin test label
-
- :param test_start_date: test start date
- :param test_end_date: test end date
- :param freq: freq
- :return: pd.DataFrame
- """
- test_end_date = transform_end_date(test_end_date, freq=freq)
- return self.raw_df.loc[(slice(None), slice(test_start_date, test_end_date)), self.label_names]
-
- @abc.abstractmethod
- def setup_feature(self):
- """
- Implement this method to load raw feature.
- the format of the feature is below
- return: df_features
- """
- pass
-
- @abc.abstractmethod
- def setup_label(self):
- """
- Implement this method to load and calculate label.
- the format of the label is below
-
- return: df_label
- """
- pass
-
-
-class QLibDataHandler(BaseDataHandler):
- def __init__(self, start_date, end_date, *args, **kwargs):
- # Dates.
- self.start_date = start_date
- self.end_date = end_date
- super().__init__(*args, **kwargs)
-
- def _init_kwargs(self, **kwargs):
-
- # Instruments
- instruments = kwargs.get("instruments", None)
- if instruments is None:
- market = kwargs.get("market", "csi500").lower()
- data_filter_list = kwargs.get("data_filter_list", list())
- self.instruments = D.instruments(market, filter_pipe=data_filter_list)
- else:
- self.instruments = instruments
-
- # Config of features and labels
- self._fields = kwargs.get("fields", [])
- self._names = kwargs.get("names", [])
- self._labels = kwargs.get("labels", [])
- self._label_names = kwargs.get("label_names", [])
-
- # Check arguments
- assert len(self._fields) > 0, "features list is empty"
- assert len(self._labels) > 0, "labels list is empty"
-
- # Check end_date
- # If test_end_date is -1 or greater than the last date, the last date is used
- self.end_date = transform_end_date(self.end_date)
-
- def setup_feature(self):
- """
- Load the raw data.
- return: df_features
- """
- TimeInspector.set_time_mark()
-
- if len(self._names) == 0:
- names = ["F%d" % i for i in range(len(self._fields))]
- else:
- names = self._names
-
- df_features = D.features(self.instruments, self._fields, self.start_date, self.end_date)
- df_features.columns = names
-
- TimeInspector.log_cost_time("Finished loading features.")
-
- return df_features
-
- def setup_label(self):
- """
- Build up labels in df through users' method
- :return: df_labels
- """
- TimeInspector.set_time_mark()
-
- if len(self._label_names) == 0:
- label_names = ["LABEL%d" % i for i in range(len(self._labels))]
- else:
- label_names = self._label_names
-
- df_labels = D.features(self.instruments, self._labels, self.start_date, self.end_date)
- df_labels.columns = label_names
-
- TimeInspector.log_cost_time("Finished loading labels.")
-
- return df_labels
-
-
-def parse_config_to_fields(config):
- """create factors from config
-
- config = {
- 'kbar': {}, # whether to use some hard-code kbar features
- 'price': { # whether to use raw price features
- 'windows': [0, 1, 2, 3, 4], # use price at n days ago
- 'feature': ['OPEN', 'HIGH', 'LOW'] # which price field to use
- },
- 'volume': { # whether to use raw volume features
- 'windows': [0, 1, 2, 3, 4], # use volume at n days ago
- },
- 'rolling': { # whether to use rolling operator based features
- 'windows': [5, 10, 20, 30, 60], # rolling windows size
- 'include': ['ROC', 'MA', 'STD'], # rolling operator to use
- #if include is None we will use default operators
- 'exclude': ['RANK'], # rolling operator not to use
- }
- }
- """
- fields = []
- names = []
- if "kbar" in config:
- fields += [
- "($close-$open)/$open",
- "($high-$low)/$open",
- "($close-$open)/($high-$low+1e-12)",
- "($high-Greater($open, $close))/$open",
- "($high-Greater($open, $close))/($high-$low+1e-12)",
- "(Less($open, $close)-$low)/$open",
- "(Less($open, $close)-$low)/($high-$low+1e-12)",
- "(2*$close-$high-$low)/$open",
- "(2*$close-$high-$low)/($high-$low+1e-12)",
- ]
- names += [
- "KMID",
- "KLEN",
- "KMID2",
- "KUP",
- "KUP2",
- "KLOW",
- "KLOW2",
- "KSFT",
- "KSFT2",
- ]
- if "price" in config:
- windows = config["price"].get("windows", range(5))
- feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"])
- for field in feature:
- field = field.lower()
- fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows]
- names += [field.upper() + str(d) for d in windows]
- if "volume" in config:
- windows = config["volume"].get("windows", range(5))
- fields += ["Ref($volume, %d)/$volume" % d if d != 0 else "$volume/$volume" for d in windows]
- names += ["VOLUME" + str(d) for d in windows]
- if "rolling" in config:
- windows = config["rolling"].get("windows", [5, 10, 20, 30, 60])
- include = config["rolling"].get("include", None)
- exclude = config["rolling"].get("exclude", [])
- # `exclude` in dataset config unnecessary filed
- # `include` in dataset config necessary field
- use = lambda x: x not in exclude and (include is None or x in include)
- if use("ROC"):
- fields += ["Ref($close, %d)/$close" % d for d in windows]
- names += ["ROC%d" % d for d in windows]
- if use("MA"):
- fields += ["Mean($close, %d)/$close" % d for d in windows]
- names += ["MA%d" % d for d in windows]
- if use("STD"):
- fields += ["Std($close, %d)/$close" % d for d in windows]
- names += ["STD%d" % d for d in windows]
- if use("BETA"):
- fields += ["Slope($close, %d)/$close" % d for d in windows]
- names += ["BETA%d" % d for d in windows]
- if use("RSQR"):
- fields += ["Rsquare($close, %d)" % d for d in windows]
- names += ["RSQR%d" % d for d in windows]
- if use("RESI"):
- fields += ["Resi($close, %d)/$close" % d for d in windows]
- names += ["RESI%d" % d for d in windows]
- if use("MAX"):
- fields += ["Max($high, %d)/$close" % d for d in windows]
- names += ["MAX%d" % d for d in windows]
- if use("LOW"):
- fields += ["Min($low, %d)/$close" % d for d in windows]
- names += ["MIN%d" % d for d in windows]
- if use("QTLU"):
- fields += ["Quantile($close, %d, 0.8)/$close" % d for d in windows]
- names += ["QTLU%d" % d for d in windows]
- if use("QTLD"):
- fields += ["Quantile($close, %d, 0.2)/$close" % d for d in windows]
- names += ["QTLD%d" % d for d in windows]
- if use("RANK"):
- fields += ["Rank($close, %d)" % d for d in windows]
- names += ["RANK%d" % d for d in windows]
- if use("RSV"):
- fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows]
- names += ["RSV%d" % d for d in windows]
- if use("IMAX"):
- fields += ["IdxMax($high, %d)/%d" % (d, d) for d in windows]
- names += ["IMAX%d" % d for d in windows]
- if use("IMIN"):
- fields += ["IdxMin($low, %d)/%d" % (d, d) for d in windows]
- names += ["IMIN%d" % d for d in windows]
- if use("IMXD"):
- fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows]
- names += ["IMXD%d" % d for d in windows]
- if use("CORR"):
- fields += ["Corr($close, Log($volume+1), %d)" % d for d in windows]
- names += ["CORR%d" % d for d in windows]
- if use("CORD"):
- fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows]
- names += ["CORD%d" % d for d in windows]
- if use("CNTP"):
- fields += ["Mean($close>Ref($close, 1), %d)" % d for d in windows]
- names += ["CNTP%d" % d for d in windows]
- if use("CNTN"):
- fields += ["Mean($close][Ref($close, 1), %d)-Mean($close][= -3, -3 - (x + 3).div(x.min() + 3) * 0.5, inplace=True)
- if self.fillna_feature:
- x.fillna(0, inplace=True)
- return x
-
- TimeInspector.set_time_mark()
-
- # Copy
- df_new = df.copy()
-
- # Label
- cols = df.columns[df.columns.str.contains("^LABEL")]
- df_new[cols] = df[cols].groupby(level="datetime").apply(_label_norm)
-
- # Features
- cols = df.columns[df.columns.str.contains("^KLEN|^KLOW|^KUP")]
- df_new[cols] = df[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
-
- cols = df.columns[df.columns.str.contains("^KLOW2|^KUP2")]
- df_new[cols] = df[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
-
- _cols = [
- "KMID",
- "KSFT",
- "OPEN",
- "HIGH",
- "LOW",
- "CLOSE",
- "VWAP",
- "ROC",
- "MA",
- "BETA",
- "RESI",
- "QTLU",
- "QTLD",
- "RSV",
- "SUMP",
- "SUMN",
- "SUMD",
- "VSUMP",
- "VSUMN",
- "VSUMD",
- ]
- pat = "|".join(["^" + x for x in _cols])
- cols = df.columns[df.columns.str.contains(pat) & (~df.columns.isin(["HIGH0", "LOW0"]))]
- df_new[cols] = df[cols].groupby(level="datetime").apply(_feature_norm)
-
- cols = df.columns[df.columns.str.contains("^STD|^VOLUME|^VMA|^VSTD")]
- df_new[cols] = df[cols].apply(np.log).groupby(level="datetime").apply(_feature_norm)
-
- cols = df.columns[df.columns.str.contains("^RSQR")]
- df_new[cols] = df[cols].fillna(0).groupby(level="datetime").apply(_feature_norm)
-
- cols = df.columns[df.columns.str.contains("^MAX|^HIGH0")]
- df_new[cols] = df[cols].apply(lambda x: (x - 1) ** 0.5).groupby(level="datetime").apply(_feature_norm)
-
- cols = df.columns[df.columns.str.contains("^MIN|^LOW0")]
- df_new[cols] = df[cols].apply(lambda x: (1 - x) ** 0.5).groupby(level="datetime").apply(_feature_norm)
-
- cols = df.columns[df.columns.str.contains("^CORR|^CORD")]
- df_new[cols] = df[cols].apply(np.exp).groupby(level="datetime").apply(_feature_norm)
-
- cols = df.columns[df.columns.str.contains("^WVMA")]
- df_new[cols] = df[cols].apply(np.log1p).groupby(level="datetime").apply(_feature_norm)
-
- TimeInspector.log_cost_time("Finished preprocessing data.")
-
- return df_new
diff --git a/qlib/contrib/estimator/trainer.py b/qlib/contrib/estimator/trainer.py
deleted file mode 100644
index 6cb57f702..000000000
--- a/qlib/contrib/estimator/trainer.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# coding=utf-8
-
-from abc import abstractmethod
-
-import pandas as pd
-import numpy as np
-from scipy.stats import pearsonr
-
-from ...log import get_module_logger, TimeInspector
-from .handler import BaseDataHandler
-from .launcher import CONFIG_MANAGER
-from .fetcher import create_fetcher_with_config
-from ...utils import drop_nan_by_y_index, transform_end_date
-
-
-class BaseTrainer(object):
- def __init__(self, model_class, model_save_path, model_args, data_handler: BaseDataHandler, sacred_ex, **kwargs):
- # 1. Model.
- self.model_class = model_class
- self.model_save_path = model_save_path
- self.model_args = model_args
-
- # 2. Data handler.
- self.data_handler = data_handler
-
- # 3. Sacred ex.
- self.ex = sacred_ex
-
- # 4. Logger.
- self.logger = get_module_logger("Trainer")
-
- # 5. Data time
- self.train_start_date = kwargs.get("train_start_date", None)
- self.train_end_date = kwargs.get("train_end_date", None)
- self.validate_start_date = kwargs.get("validate_start_date", None)
- self.validate_end_date = kwargs.get("validate_end_date", None)
- self.test_start_date = kwargs.get("test_start_date", None)
- self.test_end_date = transform_end_date(kwargs.get("test_end_date", None))
-
- @abstractmethod
- def train(self):
- """
- Implement this method indicating how to train a model.
- """
- pass
-
- @abstractmethod
- def load(self):
- """
- Implement this method indicating how to restore a model and the data.
- """
- pass
-
- @abstractmethod
- def get_test_pred(self):
- """
- Implement this method indicating how to get prediction result(s) from a model.
- """
- pass
-
- def get_test_performance(self):
- """
- Implement this method indicating how to get the performance of the model.
- """
- raise NotImplementedError(f"Please implement `get_test_performance`")
-
- def get_test_score(self):
- """
- Override this method to transfer the predict result(s) into the score of the stock.
- Note: If this is a multi-label training, you need to transfer predict labels into one score.
- Or you can just use the result of `get_test_pred()` (you can also process the result) if this is one label training.
- We use the first column of the result of `get_test_pred()` as default method (regard it as one label training).
- """
- pred = self.get_test_pred()
- pred_score = pd.DataFrame(index=pred.index)
- pred_score["score"] = pred.iloc(axis=1)[0]
- return pred_score
-
-
-class StaticTrainer(BaseTrainer):
- def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
- super(StaticTrainer, self).__init__(model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs)
- self.model = None
-
- split_data = self.data_handler.get_split_data(
- self.train_start_date,
- self.train_end_date,
- self.validate_start_date,
- self.validate_end_date,
- self.test_start_date,
- self.test_end_date,
- )
- (
- self.x_train,
- self.y_train,
- self.x_validate,
- self.y_validate,
- self.x_test,
- self.y_test,
- ) = split_data
-
- def train(self):
- TimeInspector.set_time_mark()
- model = self.model_class(**self.model_args)
-
- if CONFIG_MANAGER.ex_config.finetune:
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
-
- if isinstance(loader_model, list):
- model_index = (
- -1
- if CONFIG_MANAGER.ex_config.loader_model_index is None
- else CONFIG_MANAGER.ex_config.loader_model_index
- )
- loader_model = loader_model[model_index]
-
- model.load(loader_model)
- model.finetune(self.x_train, self.y_train, self.x_validate, self.y_validate)
- else:
- model.fit(self.x_train, self.y_train, self.x_validate, self.y_validate)
- model.save(self.model_save_path)
- self.ex.add_artifact(self.model_save_path)
- self.model = model
- TimeInspector.log_cost_time("Finished training model.")
-
- def load(self):
- model = self.model_class(**self.model_args)
-
- # Load model
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
-
- if isinstance(loader_model, list):
- model_index = (
- -1
- if CONFIG_MANAGER.ex_config.loader_model_index is None
- else CONFIG_MANAGER.ex_config.loader_model_index
- )
- loader_model = loader_model[model_index]
-
- model.load(loader_model)
-
- # Save model, after load, if you don't save the model, the result of this experiment will be no model
- model.save(self.model_save_path)
- self.ex.add_artifact(self.model_save_path)
- self.model = model
-
- def get_test_pred(self):
- pred = self.model.predict(self.x_test)
- pred = pd.DataFrame(pred, index=self.x_test.index, columns=self.y_test.columns)
- return pred
-
- def get_test_performance(self):
- try:
- model_score = self.model.score(self.x_test, self.y_test)
- except NotImplementedError:
- model_score = None
- # Remove rows from x, y and w, which contain Nan in any columns in y_test.
- x_test, y_test, __ = drop_nan_by_y_index(self.x_test, self.y_test)
- pred_test = self.model.predict(x_test)
- model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
-
- performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
- return performance
-
-
-class RollingTrainer(BaseTrainer):
- def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
- super(RollingTrainer, self).__init__(
- model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs
- )
- self.rolling_period = kwargs.get("rolling_period", 60)
- self.models = []
- self.rolling_data = []
- self.all_x_test = []
- self.all_y_test = []
- for data in self.data_handler.get_rolling_data(
- self.train_start_date,
- self.train_end_date,
- self.validate_start_date,
- self.validate_end_date,
- self.test_start_date,
- self.test_end_date,
- self.rolling_period,
- ):
- self.rolling_data.append(data)
- __, __, __, __, x_test, y_test = data
- self.all_x_test.append(x_test)
- self.all_y_test.append(y_test)
-
- def train(self):
- # 1. Get total data parts.
- # total_data_parts = self.data_handler.total_data_parts
- # self.logger.warning('Total numbers of model are: {}, start training models...'.format(total_data_parts))
- if CONFIG_MANAGER.ex_config.finetune:
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
- loader_model_index = CONFIG_MANAGER.ex_config.loader_model_index
- previous_model_path = ""
- # 2. Rolling train.
- for (
- index,
- (x_train, y_train, x_validate, y_validate, x_test, y_test),
- ) in enumerate(self.rolling_data):
- TimeInspector.set_time_mark()
- model = self.model_class(**self.model_args)
-
- if CONFIG_MANAGER.ex_config.finetune:
- # Finetune model
- if loader_model_index is None and isinstance(loader_model, list):
- try:
- model.load(loader_model[index])
- except IndexError:
- # Load model by previous_model_path
- with open(previous_model_path, "rb") as fp:
- model.load(fp)
- model.finetune(x_train, y_train, x_validate, y_validate)
- else:
-
- if index == 0:
- loader_model = (
- loader_model[loader_model_index] if isinstance(loader_model, list) else loader_model
- )
- model.load(loader_model)
- else:
- with open(previous_model_path, "rb") as fp:
- model.load(fp)
-
- model.finetune(x_train, y_train, x_validate, y_validate)
-
- else:
- model.fit(x_train, y_train, x_validate, y_validate)
-
- model_save_path = "{}_{}".format(self.model_save_path, index)
- model.save(model_save_path)
- previous_model_path = model_save_path
- self.ex.add_artifact(model_save_path)
- self.models.append(model)
- TimeInspector.log_cost_time("Finished training model: {}.".format(index + 1))
-
- def load(self):
- """
- Load the data and the model
- """
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
- for index in range(len(self.all_x_test)):
- model = self.model_class(**self.model_args)
-
- model.load(loader_model[index])
-
- # Save model
- model_save_path = "{}_{}".format(self.model_save_path, index)
- model.save(model_save_path)
- self.ex.add_artifact(model_save_path)
-
- self.models.append(model)
-
- def get_test_pred(self):
- """
- Predict the score on test data with the models.
- Please ensure the models and data are loaded before call this score.
-
- :return: the predicted scores for the pred
- """
- pred_df_list = []
- y_test_columns = self.all_y_test[0].columns
- # Start iteration.
- for model, x_test in zip(self.models, self.all_x_test):
- pred = model.predict(x_test)
- pred_df = pd.DataFrame(pred, index=x_test.index, columns=y_test_columns)
- pred_df_list.append(pred_df)
- return pd.concat(pred_df_list)
-
- def get_test_performance(self):
- """
- Get the performances of the models
-
- :return: the performances of models
- """
- pred_test_list = []
- y_test_list = []
- scorer = self.models[0]._scorer
- for model, x_test, y_test in zip(self.models, self.all_x_test, self.all_y_test):
- # Remove rows from x, y and w, which contain Nan in any columns in y_test.
- x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
- pred_test_list.append(model.predict(x_test))
- y_test_list.append(np.squeeze(y_test.values))
-
- pred_test_array = np.concatenate(pred_test_list, axis=0)
- y_test_array = np.concatenate(y_test_list, axis=0)
-
- model_score = scorer(y_test_array, pred_test_array)
- model_pearsonr = pearsonr(np.ravel(y_test_array), np.ravel(pred_test_array))[0]
-
- performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
- return performance
diff --git a/qlib/contrib/eva/__init__.py b/qlib/contrib/eva/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py
new file mode 100644
index 000000000..c68571853
--- /dev/null
+++ b/qlib/contrib/eva/alpha.py
@@ -0,0 +1,76 @@
+"""
+Here is a batch of evaluation functions.
+
+The interface should be redesigned carefully in the future.
+"""
+import pandas as pd
+
+from typing import Tuple
+
+
+def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
+ """calc_ic.
+
+ Parameters
+ ----------
+ pred :
+ pred
+ label :
+ label
+ date_col :
+ date_col
+
+ Returns
+ -------
+ (pd.Series, pd.Series)
+ ic and rank ic
+ """
+ df = pd.DataFrame({"pred": pred, "label": label})
+ ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
+ ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
+ if dropna:
+ return ic.dropna(), ric.dropna()
+ else:
+ return ic, ric
+
+
+def calc_long_short_return(
+ pred: pd.Series,
+ label: pd.Series,
+ date_col: str = "datetime",
+ quantile: float = 0.2,
+ dropna: bool = False,
+) -> Tuple[pd.Series, pd.Series]:
+ """
+ calculate long-short return
+
+ Note:
+ `label` must be raw stock returns.
+
+ Parameters
+ ----------
+ pred : pd.Series
+ stock predictions
+ label : pd.Series
+ stock returns
+ date_col : str
+ datetime index name
+ quantile : float
+ long-short quantile
+
+ Returns
+ ----------
+ long_short_r : pd.Series
+ daily long-short returns
+ long_avg_r : pd.Series
+ daily long-average returns
+ """
+ df = pd.DataFrame({"pred": pred, "label": label})
+ if dropna:
+ df.dropna(inplace=True)
+ group = df.groupby(level=date_col)
+ N = lambda x: int(len(x) * quantile)
+ r_long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label.mean())
+ r_short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label.mean())
+ r_avg = group.label.mean()
+ return (r_long - r_short) / 2, r_avg
diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py
index 8c427c16e..4bb5e4372 100644
--- a/qlib/contrib/evaluate.py
+++ b/qlib/contrib/evaluate.py
@@ -15,6 +15,7 @@ from .backtest.backtest import backtest as backtest_func, get_date_range
from ..data import D
from ..config import C
+from ..data.dataset.utils import get_level_index
logger = get_module_logger("Evaluate")
@@ -25,9 +26,9 @@ def risk_analysis(r, N=252):
Parameters
----------
r : pandas.Series
- daily return series
+ daily return series.
N: int
- scaler for annualizing information_ratio (day: 250, week: 50, month: 12)
+ scaler for annualizing information_ratio (day: 250, week: 50, month: 12).
"""
mean = r.mean()
std = r.std(ddof=1)
@@ -60,22 +61,26 @@ def get_strategy(
----------
strategy : Strategy()
- strategy used in backtest
+ strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
+ - if isinstance(margin, int):
+
sell_limit = margin
- else:
+
+ - else:
+
sell_limit = pred_in_a_day.count() * margin
- buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
- sell_limit should be no less than topk
+
+ buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
+ sell_limit should be no less than topk.
n_drop : int
- number of stocks to be replaced in each trading date
+ number of stocks to be replaced in each trading date.
risk_degree: float
- 0-1, 0.95 for example, use 95% money to trade
+ 0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
- strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
+ strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
Returns
-------
@@ -121,21 +126,21 @@ def get_exchange(
----------
# exchange related arguments
- exchange: Exchange()
+ exchange: Exchange().
subscribe_fields: list
- subscribe fields
+ subscribe fields.
open_cost : float
- open transaction cost
+ open transaction cost.
close_cost : float
- close transaction cost
+ close transaction cost.
min_cost : float
- min transaction cost
+ min transaction cost.
trade_unit : int
- 100 for China A
+ 100 for China A.
deal_price: str
- dealing price type: 'close', 'open', 'vwap'
+ dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
- limit move 0.1 (10%) for example, long and short with same limit
+ limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
@@ -158,11 +163,11 @@ def get_exchange(
if deal_price[0] != "$":
deal_price = "$" + deal_price
if extract_codes:
- codes = sorted(pred.index.get_level_values(0).unique())
+ codes = sorted(pred.index.get_level_values("instrument").unique())
else:
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
- dates = sorted(pred.index.get_level_values(1).unique())
+ dates = sorted(pred.index.get_level_values("datetime").unique())
dates = np.append(dates, get_date_range(dates[-1], shift=shift))
exchange = Exchange(
@@ -185,54 +190,61 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
Parameters
----------
- # backtest workflow related or commmon arguments
- pred : pandas.DataFrame
- predict should has index and one `score` column
- account : float
- init account value
- shift : int
- whether to shift prediction by one day
- benchmark : str
- benchmark code, default is SH000905 CSI 500
- verbose : bool
- whether to print log
+ - **backtest workflow related or commmon arguments**
+
+ pred : pandas.DataFrame
+ predict should has index and one `score` column.
+ account : float
+ init account value.
+ shift : int
+ whether to shift prediction by one day.
+ benchmark : str
+ benchmark code, default is SH000905 CSI 500.
+ verbose : bool
+ whether to print log.
+
+ - **strategy related arguments**
- # strategy related arguments
strategy : Strategy()
- strategy used in backtest
+ strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
- sell_limit = margin
- else:
- sell_limit = pred_in_a_day.count() * margin
- buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
- sell_limit should be no less than topk
- n_drop : int
- number of stocks to be replaced in each trading date
- risk_degree: float
- 0-1, 0.95 for example, use 95% money to trade
- str_type: 'amount', 'weight' or 'dropout'
- strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
+ - if isinstance(margin, int):
+
+ sell_limit = margin
+
+ - else:
+
+ sell_limit = pred_in_a_day.count() * margin
+
+ buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
+ sell_limit should be no less than topk.
+ n_drop : int
+ number of stocks to be replaced in each trading date.
+ risk_degree: float
+ 0-1, 0.95 for example, use 95% money to trade.
+ str_type: 'amount', 'weight' or 'dropout'
+ strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
+
+ - **exchange related arguments**
- # exchange related arguments
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
- subscribe fields
+ subscribe fields.
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
- min transaction cost
+ min transaction cost.
trade_unit : int
- 100 for China A
+ 100 for China A.
deal_price: str
- dealing price type: 'close', 'open', 'vwap'
+ dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
- limit move 0.1 (10%) for example, long and short with same limit
+ limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
@@ -279,17 +291,17 @@ def long_short_backtest(
"""
A backtest for long-short strategy
- :param pred: The trading signal produced on day `T`
- :param topk: The short topk securities and long topk securities
- :param deal_price: The price to deal the trading
+ :param pred: The trading signal produced on day `T`.
+ :param topk: The short topk securities and long topk securities.
+ :param deal_price: The price to deal the trading.
:param shift: Whether to shift prediction by one day. The trading day will be T+1 if shift==1.
- :param open_cost: open transaction cost
- :param close_cost: close transaction cost
- :param trade_unit: 100 for China A
- :param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit
- :param min_cost: min transaction cost
- :param subscribe_fields: subscribe fields
- :param extract_codes: bool
+ :param open_cost: open transaction cost.
+ :param close_cost: close transaction cost.
+ :param trade_unit: 100 for China A.
+ :param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit.
+ :param min_cost: min transaction cost.
+ :param subscribe_fields: subscribe fields.
+ :param extract_codes: bool.
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
:return: The result of backtest, it is represented by a dict.
@@ -297,6 +309,8 @@ def long_short_backtest(
"short": short_returns(excess),
"long_short": long_short_returns}
"""
+ if get_level_index(pred, level="datetime") == 1:
+ pred = pred.swaplevel().sort_index()
if trade_unit is None:
trade_unit = C.trade_unit
@@ -333,13 +347,13 @@ def long_short_backtest(
ls_returns = {}
for pdate, date in zip(predict_dates, trade_dates):
- score = pred.loc(axis=0)[:, pdate]
+ score = pred.loc(axis=0)[pdate, :]
score = score.reset_index().sort_values(by="score", ascending=False)
long_stocks = list(score.iloc[:topk]["instrument"])
short_stocks = list(score.iloc[-topk:]["instrument"])
- score = score.set_index(["instrument", "datetime"]).sort_index()
+ score = score.set_index(["datetime", "instrument"]).sort_index()
long_profit = []
short_profit = []
@@ -363,7 +377,7 @@ def long_short_backtest(
else:
short_profit.append(-profit)
- for stock in list(score.loc(axis=0)[:, pdate].index.get_level_values(level=0)):
+ for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
# exclude the suspend stock
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
continue
diff --git a/qlib/contrib/model/__init__.py b/qlib/contrib/model/__init__.py
index c639b57f5..e69de29bb 100644
--- a/qlib/contrib/model/__init__.py
+++ b/qlib/contrib/model/__init__.py
@@ -1,6 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import warnings
-
-from .base import Model
diff --git a/qlib/contrib/model/base.py b/qlib/contrib/model/base.py
deleted file mode 100644
index b3ea917a5..000000000
--- a/qlib/contrib/model/base.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import six
-
-
-@six.add_metaclass(abc.ABCMeta)
-class Model(object):
- """Model base class"""
-
- @property
- def name(self):
- return type(self).__name__
-
- def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
- """fix train with cross-validation
- Fit model when ex_config.finetune is False
-
- Parameters
- ----------
- x_train : pd.dataframe
- train data
- y_train : pd.dataframe
- train label
- x_valid : pd.dataframe
- valid data
- y_valid : pd.dataframe
- valid label
- w_train : pd.dataframe
- train weight
- w_valid : pd.dataframe
- valid weight
-
- Returns
- ----------
- Model
- trained model
- """
- raise NotImplementedError()
-
- def score(self, x_test, y_test, w_test=None, **kwargs):
- """evaluate model with test data/label
-
- Parameters
- ----------
- x_test : pd.dataframe
- test data
- y_test : pd.dataframe
- test label
- w_test : pd.dataframe
- test weight
-
- Returns
- ----------
- float
- evaluation score
- """
- raise NotImplementedError()
-
- def predict(self, x_test, **kwargs):
- """predict given test data
-
- Parameters
- ----------
- x_test : pd.dataframe
- test data
-
- Returns
- ----------
- np.ndarray
- test predict label
- """
- raise NotImplementedError()
-
- def save(self, fname, **kwargs):
- """save model
-
- Parameters
- ----------
- fname : str
- model filename
- """
- # TODO: Currently need to save the model as a single file, otherwise the estimator may not be compatible
- raise NotImplementedError()
-
- def load(self, buffer, **kwargs):
- """load model
-
- Parameters
- ----------
- buffer : bytes
- binary data of model parameters
-
- Returns
- ----------
- Model
- loaded model
- """
- raise NotImplementedError()
-
- def get_data_with_date(self, date, **kwargs):
- """
- Will be called in online module
- need to return the data that used to predict the label (score) of stocks at date.
-
- :param
- date: pd.Timestamp
- predict date
- :return:
- data: the input data that used to predict the label (score) of stocks at predict date.
- """
- raise NotImplementedError("get_data_with_date for this model is not implemented.")
-
- def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
- """Finetune model
- In `RollingTrainer`:
- if loader.model_index is None:
- If provide 'Static Model', based on the provided 'Static' model update.
- If provide 'Rolling Model', skip the model of load, based on the last 'provided model' update.
-
- if loader.model_index is not None:
- Based on the provided model(loader.model_index) update.
-
- In `StaticTrainer`:
- If the load is 'static model':
- Based on the 'static model' update
- If the load is 'rolling model':
- Based on the provided model(`loader.model_index`) update. If `loader.model_index` is None, use the last model.
-
- Parameters
- ----------
- x_train : pd.dataframe
- train data
- y_train : pd.dataframe
- train label
- x_valid : pd.dataframe
- valid data
- y_valid : pd.dataframe
- valid label
- w_train : pd.dataframe
- train weight
- w_valid : pd.dataframe
- valid weight
-
- Returns
- ----------
- Model
- finetune model
- """
- raise NotImplementedError("Finetune for this model is not implemented.")
diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py
new file mode 100644
index 000000000..d57c32b70
--- /dev/null
+++ b/qlib/contrib/model/catboost_model.py
@@ -0,0 +1,73 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+import pandas as pd
+from catboost import Pool, CatBoost
+from catboost.utils import get_gpu_device_count
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class CatBoostModel(Model):
+ """CatBoost Model"""
+
+ def __init__(self, loss="RMSE", **kwargs):
+ # There are more options
+ if loss not in {"RMSE", "Logloss"}:
+ raise NotImplementedError
+ self._params = {"loss_function": loss}
+ self._params.update(kwargs)
+ self.model = None
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ num_boost_round=1000,
+ early_stopping_rounds=50,
+ verbose_eval=20,
+ evals_result=dict(),
+ **kwargs
+ ):
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ # CatBoost needs 1D array as its label
+ if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
+ y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
+ else:
+ raise ValueError("CatBoost doesn't support multi-label training")
+
+ train_pool = Pool(data=x_train, label=y_train_1d)
+ valid_pool = Pool(data=x_valid, label=y_valid_1d)
+
+ # Initialize the catboost model
+ self._params["iterations"] = num_boost_round
+ self._params["early_stopping_rounds"] = early_stopping_rounds
+ self._params["verbose_eval"] = verbose_eval
+ self._params["task_type"] = "GPU" if get_gpu_device_count() > 0 else "CPU"
+ self.model = CatBoost(self._params, **kwargs)
+
+ # train the model
+ self.model.fit(train_pool, eval_set=valid_pool, use_best_model=True, **kwargs)
+
+ evals_result = self.model.get_evals_result()
+ evals_result["train"] = list(evals_result["learn"].values())[0]
+ evals_result["valid"] = list(evals_result["validation"].values())[0]
+
+ def predict(self, dataset):
+ if self.model is None:
+ raise ValueError("model is not fitted yet!")
+ x_test = dataset.prepare("test", col_set="feature")
+ return pd.Series(self.model.predict(x_test.values), index=x_test.index)
+
+
+if __name__ == "__main__":
+ cat = CatBoostModel()
diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py
index 61b902995..058d9a0e3 100644
--- a/qlib/contrib/model/gbdt.py
+++ b/qlib/contrib/model/gbdt.py
@@ -1,63 +1,54 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-
-from __future__ import division
-from __future__ import print_function
-
import numpy as np
+import pandas as pd
import lightgbm as lgb
-from sklearn.metrics import roc_auc_score, mean_squared_error
-from .base import Model
-from ...utils import drop_nan_by_y_index
+from ...model.base import ModelFT
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
-class LGBModel(Model):
- """LightGBM Model
-
- Parameters
- ----------
- param_update : dict
- training parameters
- """
-
- _params = dict()
+class LGBModel(ModelFT):
+ """LightGBM Model"""
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
- self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
- self._params.update(objective=loss, **kwargs)
- self._model = None
+ self.params = {"objective": loss, "verbosity": -1}
+ self.params.update(kwargs)
+ self.model = None
+
+ def _prepare_data(self, dataset: DatasetH):
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ # Lightgbm need 1D array as its label
+ if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
+ y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
+ else:
+ raise ValueError("LightGBM doesn't support multi-label training")
+
+ dtrain = lgb.Dataset(x_train.values, label=y_train)
+ dvalid = lgb.Dataset(x_valid.values, label=y_valid)
+ return dtrain, dvalid
def fit(
self,
- x_train,
- y_train,
- x_valid,
- y_valid,
- w_train=None,
- w_valid=None,
+ dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
- # Lightgbm need 1D array as its label
- if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
- y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
- else:
- raise ValueError("LightGBM doesn't support multi-label training")
-
- w_train_weight = None if w_train is None else w_train.values
- w_valid_weight = None if w_valid is None else w_valid.values
-
- dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
- dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
- self._model = lgb.train(
- self._params,
+ dtrain, dvalid = self._prepare_data(dataset)
+ self.model = lgb.train(
+ self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
@@ -70,22 +61,33 @@ class LGBModel(Model):
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
- def predict(self, x_test):
- if self._model is None:
+ def predict(self, dataset):
+ if self.model is None:
raise ValueError("model is not fitted yet!")
- return self._model.predict(x_test.values)
+ x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
+ return pd.Series(self.model.predict(x_test.values), index=x_test.index)
- def score(self, x_test, y_test, w_test=None):
- # Remove rows from x, y and w, which contain Nan in any columns in y_test.
- x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
- preds = self.predict(x_test)
- w_test_weight = None if w_test is None else w_test.values
- return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
+ def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
+ """
+ finetune model
- def save(self, filename):
- if self._model is None:
- raise ValueError("model is not fitted yet!")
- self._model.save_model(filename)
-
- def load(self, buffer):
- self._model = lgb.Booster(params={"model_str": buffer.decode("utf-8")})
+ Parameters
+ ----------
+ dataset : DatasetH
+ dataset for finetuning
+ num_boost_round : int
+ number of round to finetune model
+ verbose_eval : int
+ verbose level
+ """
+ # Based on existing model and finetune by train more rounds
+ dtrain, _ = self._prepare_data(dataset)
+ self.model = lgb.train(
+ self.params,
+ dtrain,
+ num_boost_round=num_boost_round,
+ init_model=self.model,
+ valid_sets=[dtrain],
+ valid_names=["train"],
+ verbose_eval=verbose_eval,
+ )
diff --git a/qlib/contrib/model/linear.py b/qlib/contrib/model/linear.py
new file mode 100644
index 000000000..0f9223737
--- /dev/null
+++ b/qlib/contrib/model/linear.py
@@ -0,0 +1,91 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+import pandas as pd
+
+from scipy.optimize import nnls
+from sklearn.linear_model import LinearRegression, Ridge, Lasso
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class LinearModel(Model):
+ """Linear Model
+
+ Solve one of the following regression problems:
+ - `ols`: min_w |y - Xw|^2_2
+ - `nnls`: min_w |y - Xw|^2_2, s.t. w >= 0
+ - `ridge`: min_w |y - Xw|^2_2 + \alpha*|w|^2_2
+ - `lasso`: min_w |y - Xw|^2_2 + \alpha*|w|_1
+ where `w` is the regression coefficient.
+ """
+
+ OLS = "ols"
+ NNLS = "nnls"
+ RIDGE = "ridge"
+ LASSO = "lasso"
+
+ def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False):
+ """
+ Parameters
+ ----------
+ estimator : str
+ which estimator to use for linear regression
+ alpha : float
+ l1 or l2 regularization parameter
+ fit_intercept : bool
+ whether fit intercept
+ """
+ assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f"unsupported estimator `{estimator}`"
+ self.estimator = estimator
+
+ assert alpha == 0 or estimator in [self.RIDGE, self.LASSO], f"alpha is only supported in `ridge`&`lasso`"
+ self.alpha = alpha
+
+ self.fit_intercept = fit_intercept
+
+ self.coef_ = None
+
+ def fit(self, dataset: DatasetH):
+ df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
+ X, y = df_train["feature"].values, np.squeeze(df_train["label"].values)
+
+ if self.estimator in [self.OLS, self.RIDGE, self.LASSO]:
+ self._fit(X, y)
+ elif self.estimator == self.NNLS:
+ self._fit_nnls(X, y)
+ else:
+ raise ValueError(f"unknown estimator `{self.estimator}`")
+
+ return self
+
+ def _fit(self, X, y):
+ if self.estimator == self.OLS:
+ model = LinearRegression(fit_intercept=self.fit_intercept, copy_X=False)
+ else:
+ model = {self.RIDGE: Ridge, self.LASSO: Lasso}[self.estimator](
+ alpha=self.alpha, fit_intercept=self.fit_intercept, copy_X=False
+ )
+ model.fit(X, y)
+ self.coef_ = model.coef_
+ self.intercept_ = model.intercept_
+
+ def _fit_nnls(self, X, y):
+ if self.fit_intercept:
+ X = np.c_[X, np.ones(len(X))] # NOTE: mem copy
+ coef = nnls(X, y)[0]
+ if self.fit_intercept:
+ self.coef_ = coef[:-1]
+ self.intercept_ = coef[-1]
+ else:
+ self.coef_ = coef
+ self.intercept_ = 0.0
+
+ def predict(self, dataset):
+ if self.coef_ is None:
+ raise ValueError("model is not fitted yet!")
+ x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
+ return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)
diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py
new file mode 100644
index 000000000..40c2f8226
--- /dev/null
+++ b/qlib/contrib/model/pytorch_alstm.py
@@ -0,0 +1,364 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from sklearn.metrics import roc_auc_score, mean_squared_error
+import logging
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
+from ...log import get_module_logger, TimeInspector
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class ALSTM(Model):
+ """ALSTM Model
+
+ Parameters
+ ----------
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ num_layers=2,
+ dropout=0.0,
+ n_epochs=200,
+ lr=0.001,
+ metric="",
+ batch_size=2000,
+ early_stop=20,
+ loss="mse",
+ optimizer="adam",
+ GPU="0",
+ seed=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("ALSTM")
+ self.logger.info("ALSTM pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.n_epochs = n_epochs
+ self.lr = lr
+ self.metric = metric
+ self.batch_size = batch_size
+ self.early_stop = early_stop
+ self.optimizer = optimizer.lower()
+ self.loss = loss
+ self.visible_GPU = GPU
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+
+ self.logger.info(
+ "ALSTM parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\nnum_layers : {}"
+ "\ndropout : {}"
+ "\nn_epochs : {}"
+ "\nlr : {}"
+ "\nmetric : {}"
+ "\nbatch_size : {}"
+ "\nearly_stop : {}"
+ "\noptimizer : {}"
+ "\nloss_type : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ num_layers,
+ dropout,
+ n_epochs,
+ lr,
+ metric,
+ batch_size,
+ early_stop,
+ optimizer.lower(),
+ loss,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ self.ALSTM_model = ALSTMModel(
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
+ )
+ if optimizer.lower() == "adam":
+ self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
+ elif optimizer.lower() == "gd":
+ self.train_optimizer = optim.SGD(self.ALSTM_model.parameters(), lr=self.lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+ self._fitted = False
+ if self.use_gpu:
+ self.ALSTM_model.cuda()
+ # set the visible GPU
+ if self.visible_GPU:
+ os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
+
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss":
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ self.ALSTM_model.train()
+
+ indices = np.arange(len(x_train_values))
+ np.random.shuffle(indices)
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.ALSTM_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.ALSTM_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.ALSTM_model.eval()
+
+ scores = []
+ losses = []
+
+ indices = np.arange(len(x_values))
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.ALSTM_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid, df_test = dataset.prepare(
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ if save_path == None:
+ save_path = create_save_path(save_path)
+ stop_steps = 0
+ train_loss = 0
+ best_score = -np.inf
+ best_epoch = 0
+ evals_result["train"] = []
+ evals_result["valid"] = []
+
+ # train
+ self.logger.info("training...")
+ self._fitted = True
+
+ for step in range(self.n_epochs):
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
+
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.ALSTM_model.state_dict())
+ else:
+ stop_steps += 1
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+ self.ALSTM_model.load_state_dict(best_param)
+ torch.save(best_param, save_path)
+
+ if self.use_gpu:
+ torch.cuda.empty_cache()
+
+ def predict(self, dataset):
+ if not self._fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.ALSTM_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
+
+ for begin in range(sample_num)[:: self.batch_size]:
+
+ if sample_num - begin < self.batch_size:
+ end = sample_num
+ else:
+ end = begin + self.batch_size
+
+ x_batch = torch.from_numpy(x_values[begin:end]).float()
+
+ if self.use_gpu:
+ x_batch = x_batch.cuda()
+
+ with torch.no_grad():
+ if self.use_gpu:
+ pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
+ else:
+ pred = self.ALSTM_model(x_batch).detach().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class ALSTMModel(nn.Module):
+ def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU"):
+ super().__init__()
+ self.hid_size = hidden_size
+ self.input_size = d_feat
+ self.dropout = dropout
+ self.rnn_type = rnn_type
+ self.rnn_layer = num_layers
+ self._build_model()
+
+ def _build_model(self):
+ try:
+ klass = getattr(nn, self.rnn_type.upper())
+ except:
+ raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
+ self.net = nn.Sequential()
+ self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
+ self.net.add_module("act", nn.Tanh())
+ self.rnn = klass(
+ input_size=self.hid_size,
+ hidden_size=self.hid_size,
+ num_layers=self.rnn_layer,
+ batch_first=True,
+ dropout=self.dropout,
+ )
+ self.fc_out = nn.Linear(in_features=self.hid_size * 2, out_features=1)
+ self.att_net = nn.Sequential()
+ self.att_net.add_module(
+ "att_fc_in",
+ nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)),
+ )
+ self.att_net.add_module("att_dropout", torch.nn.Dropout(self.dropout))
+ self.att_net.add_module("att_act", nn.Tanh())
+ self.att_net.add_module(
+ "att_fc_out",
+ nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False),
+ )
+ self.att_net.add_module("att_softmax", nn.Softmax(dim=1))
+
+ def forward(self, inputs):
+ # inputs: [batch_size, input_size*input_day]
+ inputs = inputs.view(len(inputs), self.input_size, -1)
+ inputs = inputs.permute(0, 2, 1) # [batch, input_size, seq_len] -> [batch, seq_len, input_size]
+ rnn_out, _ = self.rnn(self.net(inputs)) # [batch, seq_len, num_directions * hidden_size]
+ attention_score = self.att_net(rnn_out) # [batch, seq_len, 1]
+ out_att = torch.mul(rnn_out, attention_score)
+ out_att = torch.sum(out_att, dim=1)
+ out = self.fc_out(
+ torch.cat((rnn_out[:, -1, :], out_att), dim=1)
+ ) # [batch, seq_len, num_directions * hidden_size] -> [batch, 1]
+ return out[..., 0]
diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py
new file mode 100644
index 000000000..e9cbcf9cb
--- /dev/null
+++ b/qlib/contrib/model/pytorch_gats.py
@@ -0,0 +1,394 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from ...utils import create_save_path
+from ...log import get_module_logger
+
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+from ...contrib.model.pytorch_lstm import LSTMModel
+from ...contrib.model.pytorch_gru import GRUModel
+
+
+class GATs(Model):
+ """GATs Model
+
+ Parameters
+ ----------
+ lr : float
+ learning rate
+ d_feat : int
+ input dimensions for each time step
+ metric : str
+ the evaluate metric used in early stop
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ num_layers=2,
+ dropout=0.0,
+ n_epochs=200,
+ lr=0.001,
+ metric="",
+ early_stop=20,
+ loss="mse",
+ base_model="GRU",
+ with_pretrain=True,
+ optimizer="adam",
+ GPU="0",
+ seed=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("GATs")
+ self.logger.info("GATs pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.n_epochs = n_epochs
+ self.lr = lr
+ self.metric = metric
+ self.early_stop = early_stop
+ self.optimizer = optimizer.lower()
+ self.loss = loss
+ self.base_model = base_model
+ self.with_pretrain = with_pretrain
+ self.visible_GPU = GPU
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+
+ self.logger.info(
+ "GATs parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\nnum_layers : {}"
+ "\ndropout : {}"
+ "\nn_epochs : {}"
+ "\nlr : {}"
+ "\nmetric : {}"
+ "\nearly_stop : {}"
+ "\noptimizer : {}"
+ "\nloss_type : {}"
+ "\nbase_model : {}"
+ "\nwith_pretrain : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ num_layers,
+ dropout,
+ n_epochs,
+ lr,
+ metric,
+ early_stop,
+ optimizer.lower(),
+ loss,
+ base_model,
+ with_pretrain,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ self.GAT_model = GATModel(
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
+ base_model=self.base_model,
+ )
+ if optimizer.lower() == "adam":
+ self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
+ elif optimizer.lower() == "gd":
+ self.train_optimizer = optim.SGD(self.GAT_model.parameters(), lr=self.lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+ self._fitted = False
+ if self.use_gpu:
+ self.GAT_model.cuda()
+ # set the visible GPU
+ if self.visible_GPU:
+ os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
+
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss":
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
+
+ def get_daily_inter(self, df, shuffle=False):
+ # organize the train data into daily batches
+ daily_count = df.groupby(level=0).size().values
+ daily_index = np.roll(np.cumsum(daily_count), 1)
+ daily_index[0] = 0
+ if shuffle:
+ # shuffle data
+ daily_shuffle = list(zip(daily_index, daily_count))
+ np.random.shuffle(daily_shuffle)
+ daily_index, daily_count = zip(*daily_shuffle)
+ return daily_index, daily_count
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+ self.GAT_model.train()
+
+ # organize the train data into daily batches
+ daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
+
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ feature = torch.from_numpy(x_train_values[batch]).float()
+ label = torch.from_numpy(y_train_values[batch]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.GAT_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.GAT_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.GAT_model.eval()
+
+ scores = []
+ losses = []
+
+ # organize the test data into daily batches
+ daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
+
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ feature = torch.from_numpy(x_values[batch]).float()
+ label = torch.from_numpy(y_values[batch]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.GAT_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid, df_test = dataset.prepare(
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ if save_path == None:
+ save_path = create_save_path(save_path)
+ stop_steps = 0
+ best_score = -np.inf
+ best_epoch = 0
+ evals_result["train"] = []
+ evals_result["valid"] = []
+
+ # load pretrained base_model
+ if self.with_pretrain:
+ self.logger.info("Loading pretrained model...")
+ if self.base_model == "LSTM":
+ pretrained_model = LSTMModel()
+ pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
+
+ elif self.base_model == "GRU":
+ pretrained_model = GRUModel()
+ pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
+
+ model_dict = self.GAT_model.state_dict()
+ pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
+ model_dict.update(pretrained_dict)
+ self.GAT_model.load_state_dict(model_dict)
+ self.logger.info("Loading pretrained model Done...")
+
+ # train
+ self.logger.info("training...")
+ self._fitted = True
+
+ for step in range(self.n_epochs):
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
+
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.GAT_model.state_dict())
+ else:
+ stop_steps += 1
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+ self.GAT_model.load_state_dict(best_param)
+ torch.save(best_param, save_path)
+
+ if self.use_gpu:
+ torch.cuda.empty_cache()
+
+ def predict(self, dataset):
+ if not self._fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.GAT_model.eval()
+ x_values = x_test.values
+ preds = []
+
+ # organize the data into daily batches
+ daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
+
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ x_batch = torch.from_numpy(x_values[batch]).float()
+
+ if self.use_gpu:
+ x_batch = x_batch.cuda()
+
+ with torch.no_grad():
+ if self.use_gpu:
+ pred = self.GAT_model(x_batch).detach().cpu().numpy()
+ else:
+ pred = self.GAT_model(x_batch).detach().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class GATModel(nn.Module):
+ def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
+ super().__init__()
+
+ if base_model == "GRU":
+ self.rnn = nn.GRU(
+ input_size=d_feat,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout,
+ )
+ elif base_model == "LSTM":
+ self.rnn = nn.LSTM(
+ input_size=d_feat,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout,
+ )
+ else:
+ raise ValueError("unknown base model name `%s`" % base_model)
+
+ self.hidden_size = hidden_size
+ self.d_feat = d_feat
+ self.transformation = nn.Linear(self.hidden_size, self.hidden_size)
+ self.a = nn.Parameter(torch.randn(self.hidden_size * 2, 1))
+ self.a.requires_grad = True
+ self.fc = nn.Linear(self.hidden_size, self.hidden_size)
+ self.fc_out = nn.Linear(hidden_size, 1)
+ self.leaky_relu = nn.LeakyReLU()
+ self.softmax = nn.Softmax(dim=1)
+
+ def cal_attention(self, x, y):
+ x = self.transformation(x)
+ y = self.transformation(y)
+
+ sample_num = x.shape[0]
+ dim = x.shape[1]
+ e_x = x.expand(sample_num, sample_num, dim)
+ e_y = torch.transpose(e_x, 0, 1)
+ attention_in = torch.cat((e_x, e_y), 2).view(-1, dim * 2)
+ self.a_t = torch.t(self.a)
+ attention_out = self.a_t.mm(torch.t(attention_in)).view(sample_num, sample_num)
+ attention_out = self.leaky_relu(attention_out)
+ att_weight = self.softmax(attention_out)
+ return att_weight
+
+ def forward(self, x):
+ # x: [N, F*T]
+ x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
+ x = x.permute(0, 2, 1) # [N, T, F]
+ out, _ = self.rnn(x)
+ hidden = out[:, -1, :]
+ att_weight = self.cal_attention(hidden, hidden)
+ hidden = att_weight.mm(hidden) + hidden
+ hidden = self.fc(hidden)
+ hidden = self.leaky_relu(hidden)
+ return self.fc_out(hidden).squeeze()
diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py
new file mode 100755
index 000000000..5daf4707e
--- /dev/null
+++ b/qlib/contrib/model/pytorch_gru.py
@@ -0,0 +1,334 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from sklearn.metrics import roc_auc_score, mean_squared_error
+import logging
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
+from ...log import get_module_logger, TimeInspector
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class GRU(Model):
+ """GRU Model
+
+ Parameters
+ ----------
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ num_layers=2,
+ dropout=0.0,
+ n_epochs=200,
+ lr=0.001,
+ metric="",
+ batch_size=2000,
+ early_stop=20,
+ loss="mse",
+ optimizer="adam",
+ GPU="0",
+ seed=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("GRU")
+ self.logger.info("GRU pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.n_epochs = n_epochs
+ self.lr = lr
+ self.metric = metric
+ self.batch_size = batch_size
+ self.early_stop = early_stop
+ self.optimizer = optimizer.lower()
+ self.loss = loss
+ self.visible_GPU = GPU
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+
+ self.logger.info(
+ "GRU parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\nnum_layers : {}"
+ "\ndropout : {}"
+ "\nn_epochs : {}"
+ "\nlr : {}"
+ "\nmetric : {}"
+ "\nbatch_size : {}"
+ "\nearly_stop : {}"
+ "\noptimizer : {}"
+ "\nloss_type : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ num_layers,
+ dropout,
+ n_epochs,
+ lr,
+ metric,
+ batch_size,
+ early_stop,
+ optimizer.lower(),
+ loss,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ self.gru_model = GRUModel(
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
+ )
+ if optimizer.lower() == "adam":
+ self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
+ elif optimizer.lower() == "gd":
+ self.train_optimizer = optim.SGD(self.gru_model.parameters(), lr=self.lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+ self._fitted = False
+ if self.use_gpu:
+ self.gru_model.cuda()
+ # set the visible GPU
+ if self.visible_GPU:
+ os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
+
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss":
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ self.gru_model.train()
+
+ indices = np.arange(len(x_train_values))
+ np.random.shuffle(indices)
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.gru_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.gru_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.gru_model.eval()
+
+ scores = []
+ losses = []
+
+ indices = np.arange(len(x_values))
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.gru_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid, df_test = dataset.prepare(
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ if save_path == None:
+ save_path = create_save_path(save_path)
+ stop_steps = 0
+ train_loss = 0
+ best_score = -np.inf
+ best_epoch = 0
+ evals_result["train"] = []
+ evals_result["valid"] = []
+
+ # train
+ self.logger.info("training...")
+ self._fitted = True
+
+ for step in range(self.n_epochs):
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
+
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.gru_model.state_dict())
+ else:
+ stop_steps += 1
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+ self.gru_model.load_state_dict(best_param)
+ torch.save(best_param, save_path)
+
+ if self.use_gpu:
+ torch.cuda.empty_cache()
+
+ def predict(self, dataset):
+ if not self._fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.gru_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
+
+ for begin in range(sample_num)[:: self.batch_size]:
+
+ if sample_num - begin < self.batch_size:
+ end = sample_num
+ else:
+ end = begin + self.batch_size
+
+ x_batch = torch.from_numpy(x_values[begin:end]).float()
+
+ if self.use_gpu:
+ x_batch = x_batch.cuda()
+
+ with torch.no_grad():
+ if self.use_gpu:
+ pred = self.gru_model(x_batch).detach().cpu().numpy()
+ else:
+ pred = self.gru_model(x_batch).detach().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class GRUModel(nn.Module):
+ def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
+ super().__init__()
+
+ self.rnn = nn.GRU(
+ input_size=d_feat,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout,
+ )
+ self.fc_out = nn.Linear(hidden_size, 1)
+
+ self.d_feat = d_feat
+
+ def forward(self, x):
+ # x: [N, F*T]
+ x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
+ x = x.permute(0, 2, 1) # [N, T, F]
+ out, _ = self.rnn(x)
+ return self.fc_out(out[:, -1, :]).squeeze()
diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py
new file mode 100755
index 000000000..eef1680ec
--- /dev/null
+++ b/qlib/contrib/model/pytorch_lstm.py
@@ -0,0 +1,334 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from sklearn.metrics import roc_auc_score, mean_squared_error
+import logging
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
+from ...log import get_module_logger, TimeInspector
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class LSTM(Model):
+ """LSTM Model
+
+ Parameters
+ ----------
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ num_layers=2,
+ dropout=0.0,
+ n_epochs=200,
+ lr=0.001,
+ metric="",
+ batch_size=2000,
+ early_stop=20,
+ loss="mse",
+ optimizer="adam",
+ GPU="0",
+ seed=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("LSTM")
+ self.logger.info("LSTM pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.n_epochs = n_epochs
+ self.lr = lr
+ self.metric = metric
+ self.batch_size = batch_size
+ self.early_stop = early_stop
+ self.optimizer = optimizer.lower()
+ self.loss = loss
+ self.visible_GPU = GPU
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+
+ self.logger.info(
+ "LSTM parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\nnum_layers : {}"
+ "\ndropout : {}"
+ "\nn_epochs : {}"
+ "\nlr : {}"
+ "\nmetric : {}"
+ "\nbatch_size : {}"
+ "\nearly_stop : {}"
+ "\noptimizer : {}"
+ "\nloss_type : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ num_layers,
+ dropout,
+ n_epochs,
+ lr,
+ metric,
+ batch_size,
+ early_stop,
+ optimizer.lower(),
+ loss,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ self.lstm_model = LSTMModel(
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
+ )
+ if optimizer.lower() == "adam":
+ self.train_optimizer = optim.Adam(self.lstm_model.parameters(), lr=self.lr)
+ elif optimizer.lower() == "gd":
+ self.train_optimizer = optim.SGD(self.lstm_model.parameters(), lr=self.lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+ self._fitted = False
+ if self.use_gpu:
+ self.lstm_model.cuda()
+ # set the visible GPU
+ if self.visible_GPU:
+ os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
+
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss":
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ self.lstm_model.train()
+
+ indices = np.arange(len(x_train_values))
+ np.random.shuffle(indices)
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.lstm_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.lstm_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.lstm_model.eval()
+
+ scores = []
+ losses = []
+
+ indices = np.arange(len(x_values))
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.lstm_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid, df_test = dataset.prepare(
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ if save_path == None:
+ save_path = create_save_path(save_path)
+ stop_steps = 0
+ train_loss = 0
+ best_score = -np.inf
+ best_epoch = 0
+ evals_result["train"] = []
+ evals_result["valid"] = []
+
+ # train
+ self.logger.info("training...")
+ self._fitted = True
+
+ for step in range(self.n_epochs):
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
+
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.lstm_model.state_dict())
+ else:
+ stop_steps += 1
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+ self.lstm_model.load_state_dict(best_param)
+ torch.save(best_param, save_path)
+
+ if self.use_gpu:
+ torch.cuda.empty_cache()
+
+ def predict(self, dataset):
+ if not self._fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.lstm_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
+
+ for begin in range(sample_num)[:: self.batch_size]:
+
+ if sample_num - begin < self.batch_size:
+ end = sample_num
+ else:
+ end = begin + self.batch_size
+
+ x_batch = torch.from_numpy(x_values[begin:end]).float()
+
+ if self.use_gpu:
+ x_batch = x_batch.cuda()
+
+ with torch.no_grad():
+ if self.use_gpu:
+ pred = self.lstm_model(x_batch).detach().cpu().numpy()
+ else:
+ pred = self.lstm_model(x_batch).detach().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class LSTMModel(nn.Module):
+ def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
+ super().__init__()
+
+ self.rnn = nn.LSTM(
+ input_size=d_feat,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout,
+ )
+ self.fc_out = nn.Linear(hidden_size, 1)
+
+ self.d_feat = d_feat
+
+ def forward(self, x):
+ # x: [N, F*T]
+ x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
+ x = x.permute(0, 2, 1) # [N, T, F]
+ out, _ = self.rnn(x)
+ return self.fc_out(out[:, -1, :]).squeeze()
diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py
index 48b643bf8..d324e27aa 100644
--- a/qlib/contrib/model/pytorch_nn.py
+++ b/qlib/contrib/model/pytorch_nn.py
@@ -6,18 +6,21 @@ from __future__ import division
from __future__ import print_function
import os
+import logging
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, mean_squared_error
-import logging
-from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
-from ...log import get_module_logger, TimeInspector
import torch
import torch.nn as nn
import torch.optim as optim
-from .base import Model
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
+from ...log import get_module_logger, TimeInspector
+from ...workflow import R
class DNNModelPytorch(Model):
@@ -47,7 +50,7 @@ class DNNModelPytorch(Model):
self,
input_dim,
output_dim,
- layers=(256, 512, 768, 1024, 768, 512, 256, 128, 64),
+ layers=(256, 512, 768, 512, 256, 128, 64),
lr=0.001,
max_steps=300,
batch_size=2000,
@@ -76,7 +79,7 @@ class DNNModelPytorch(Model):
self.optimizer = optimizer.lower()
self.loss_type = loss
self.visible_GPU = GPU
- self.use_gpu = torch.cuda.is_available()
+ self.use_GPU = torch.cuda.is_available()
self.logger.info(
"DNN parameters setting:"
@@ -105,7 +108,7 @@ class DNNModelPytorch(Model):
loss,
eval_steps,
GPU,
- self.use_gpu,
+ self.use_GPU,
)
)
@@ -136,7 +139,7 @@ class DNNModelPytorch(Model):
)
self._fitted = False
- if self.use_gpu:
+ if self.use_GPU:
self.dnn_model.cuda()
# set the visible GPU
if self.visible_GPU:
@@ -144,20 +147,21 @@ class DNNModelPytorch(Model):
def fit(
self,
- x_train,
- y_train,
- x_valid,
- y_valid,
- w_train=None,
- w_valid=None,
+ dataset: DatasetH,
evals_result=dict(),
verbose=True,
save_path=None,
):
-
- if w_train is None:
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+ try:
+ wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
+ w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
+ except KeyError as e:
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
- if w_valid is None:
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
save_path = create_save_path(save_path)
@@ -166,7 +170,6 @@ class DNNModelPytorch(Model):
best_loss = np.inf
evals_result["train"] = []
evals_result["valid"] = []
-
# train
self.logger.info("training...")
self._fitted = True
@@ -176,13 +179,11 @@ class DNNModelPytorch(Model):
y_train_values = torch.from_numpy(y_train.values).float()
w_train_values = torch.from_numpy(w_train.values).float()
train_num = y_train_values.shape[0]
-
# prepare validation data
x_val_auto = torch.from_numpy(x_valid.values).float()
y_val_auto = torch.from_numpy(y_valid.values).float()
w_val_auto = torch.from_numpy(w_valid.values).float()
-
- if self.use_gpu:
+ if self.use_GPU:
x_val_auto = x_val_auto.cuda()
y_val_auto = y_val_auto.cuda()
w_val_auto = w_val_auto.cuda()
@@ -195,16 +196,15 @@ class DNNModelPytorch(Model):
loss = AverageMeter()
self.dnn_model.train()
self.train_optimizer.zero_grad()
-
choice = np.random.choice(train_num, self.batch_size)
x_batch_auto = x_train_values[choice]
y_batch_auto = y_train_values[choice]
w_batch_auto = w_train_values[choice]
- if self.use_gpu:
- x_batch_auto = x_batch_auto.float().cuda()
- y_batch_auto = y_batch_auto.float().cuda()
- w_batch_auto = w_batch_auto.float().cuda()
+ if self.use_GPU:
+ x_batch_auto = x_batch_auto.cuda()
+ y_batch_auto = y_batch_auto.cuda()
+ w_batch_auto = w_batch_auto.cuda()
# forward
preds = self.dnn_model(x_batch_auto)
@@ -212,10 +212,10 @@ class DNNModelPytorch(Model):
cur_loss.backward()
self.train_optimizer.step()
loss.update(cur_loss.item())
+ R.log_metrics(train_loss=loss.avg, step=step)
# validation
train_loss += loss.val
- # print(loss.val)
if step and step % self.eval_steps == 0:
stop_steps += 1
train_loss /= self.eval_steps
@@ -228,6 +228,7 @@ class DNNModelPytorch(Model):
preds = self.dnn_model(x_val_auto)
cur_loss_val = self.get_loss(preds, w_val_auto, y_val_auto, self.loss_type)
loss_val.update(cur_loss_val.item())
+ R.log_metrics(val_loss=loss_val.val, step=step)
if verbose:
self.logger.info(
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
@@ -250,7 +251,7 @@ class DNNModelPytorch(Model):
# restore the optimal parameters after training ??
self.dnn_model.load_state_dict(torch.load(save_path))
- if self.use_gpu:
+ if self.use_GPU:
torch.cuda.empty_cache()
def get_loss(self, pred, w, target, loss_type):
@@ -264,27 +265,21 @@ class DNNModelPytorch(Model):
else:
raise NotImplementedError("loss {} is not supported!".format(loss_type))
- def predict(self, x_test):
+ def predict(self, dataset):
if not self._fitted:
raise ValueError("model is not fitted yet!")
- x_test = torch.from_numpy(x_test.values).float()
- if self.use_gpu:
+ x_test_pd = dataset.prepare("test", col_set="feature")
+ x_test = torch.from_numpy(x_test_pd.values).float()
+ if self.use_GPU:
x_test = x_test.cuda()
self.dnn_model.eval()
with torch.no_grad():
- if self.use_gpu:
+ if self.use_GPU:
preds = self.dnn_model(x_test).detach().cpu().numpy()
else:
preds = self.dnn_model(x_test).detach().numpy()
- return preds
-
- def score(self, x_test, y_test, w_test=None):
- # Remove rows from x, y and w, which contain Nan in any columns in y_test.
- x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
- preds = self.predict(x_test)
- w_test_weight = None if w_test is None else w_test.values
- return self._scorer(y_test.values, preds, sample_weight=w_test_weight)
+ return pd.Series(np.squeeze(preds), index=x_test_pd.index)
def save(self, filename, **kwargs):
with save_multiple_parts_file(filename) as model_dir:
@@ -303,9 +298,6 @@ class DNNModelPytorch(Model):
self.dnn_model.load_state_dict(torch.load(_model_path))
self._fitted = True
- def finetune(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs):
- self.fit(x_train, y_train, x_valid, y_valid, w_train=w_train, w_valid=w_valid, **kwargs)
-
class AverageMeter(object):
"""Computes and stores the average and current value"""
@@ -335,7 +327,7 @@ class Net(nn.Module):
dnn_layers.append(drop_input)
for i, (input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
fc = nn.Linear(input_dim, hidden_units)
- activation = nn.ReLU()
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
bn = nn.BatchNorm1d(hidden_units)
seq = nn.Sequential(fc, bn, activation)
dnn_layers.append(seq)
@@ -358,7 +350,7 @@ class Net(nn.Module):
def _weight_init(self):
for m in self.modules():
if isinstance(m, nn.Linear):
- nn.init.xavier_normal_(m.weight, gain=1)
+ nn.init.kaiming_normal_(m.weight, a=0.1, mode="fan_in", nonlinearity="leaky_relu")
def forward(self, x):
cur_output = x
diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py
new file mode 100644
index 000000000..228c0aee5
--- /dev/null
+++ b/qlib/contrib/model/pytorch_sfm.py
@@ -0,0 +1,479 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from sklearn.metrics import roc_auc_score, mean_squared_error
+import logging
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
+from ...log import get_module_logger, TimeInspector
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class SFM_Model(nn.Module):
+ def __init__(
+ self,
+ d_feat=6,
+ output_dim=1,
+ freq_dim=10,
+ hidden_size=64,
+ dropout_W=0.0,
+ dropout_U=0.0,
+ device="cpu",
+ ):
+ super().__init__()
+
+ self.input_dim = d_feat
+ self.output_dim = output_dim
+ self.freq_dim = freq_dim
+ self.hidden_dim = hidden_size
+ self.device = device
+
+ self.W_i = nn.Parameter(init.xavier_uniform_(torch.empty((self.input_dim, self.hidden_dim))))
+ self.U_i = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
+ self.b_i = nn.Parameter(torch.zeros(self.hidden_dim))
+
+ self.W_ste = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))
+ self.U_ste = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
+ self.b_ste = nn.Parameter(torch.ones(self.hidden_dim))
+
+ self.W_fre = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.freq_dim)))
+ self.U_fre = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.freq_dim)))
+ self.b_fre = nn.Parameter(torch.ones(self.freq_dim))
+
+ self.W_c = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))
+ self.U_c = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
+ self.b_c = nn.Parameter(torch.zeros(self.hidden_dim))
+
+ self.W_o = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))
+ self.U_o = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
+ self.b_o = nn.Parameter(torch.zeros(self.hidden_dim))
+
+ self.U_a = nn.Parameter(init.orthogonal_(torch.empty(self.freq_dim, 1)))
+ self.b_a = nn.Parameter(torch.zeros(self.hidden_dim))
+
+ self.W_p = nn.Parameter(init.xavier_uniform_(torch.empty(self.hidden_dim, self.output_dim)))
+ self.b_p = nn.Parameter(torch.zeros(self.output_dim))
+
+ self.activation = nn.Tanh()
+ self.inner_activation = nn.Hardsigmoid()
+ self.dropout_W, self.dropout_U = (dropout_W, dropout_U)
+ self.fc_out = nn.Linear(self.output_dim, 1)
+
+ self.states = []
+
+ def forward(self, input):
+ input = input.reshape(len(input), self.input_dim, -1) # [N, F, T]
+ input = input.permute(0, 2, 1) # [N, T, F]
+ time_step = input.shape[1]
+
+ for ts in range(time_step):
+ x = input[:, ts, :]
+ if len(self.states) == 0: # hasn't initialized yet
+ self.init_states(x)
+ self.get_constants(x)
+ p_tm1 = self.states[0]
+ h_tm1 = self.states[1]
+ S_re_tm1 = self.states[2]
+ S_im_tm1 = self.states[3]
+ time_tm1 = self.states[4]
+ B_U = self.states[5]
+ B_W = self.states[6]
+ frequency = self.states[7]
+
+ x_i = torch.matmul(x * B_W[0], self.W_i) + self.b_i
+ x_ste = torch.matmul(x * B_W[0], self.W_ste) + self.b_ste
+ x_fre = torch.matmul(x * B_W[0], self.W_fre) + self.b_fre
+ x_c = torch.matmul(x * B_W[0], self.W_c) + self.b_c
+ x_o = torch.matmul(x * B_W[0], self.W_o) + self.b_o
+
+ i = self.inner_activation(x_i + torch.matmul(h_tm1 * B_U[0], self.U_i))
+ ste = self.inner_activation(x_ste + torch.matmul(h_tm1 * B_U[0], self.U_ste))
+ fre = self.inner_activation(x_fre + torch.matmul(h_tm1 * B_U[0], self.U_fre))
+
+ ste = torch.reshape(ste, (-1, self.hidden_dim, 1))
+ fre = torch.reshape(fre, (-1, 1, self.freq_dim))
+
+ f = ste * fre
+
+ c = i * self.activation(x_c + torch.matmul(h_tm1 * B_U[0], self.U_c))
+
+ time = time_tm1 + 1
+
+ omega = torch.tensor(2 * np.pi) * time * frequency
+
+ re = torch.cos(omega)
+ im = torch.sin(omega)
+
+ c = torch.reshape(c, (-1, self.hidden_dim, 1))
+
+ S_re = f * S_re_tm1 + c * re
+ S_im = f * S_im_tm1 + c * im
+
+ A = torch.square(S_re) + torch.square(S_im)
+
+ A = torch.reshape(A, (-1, self.freq_dim)).float()
+ A_a = torch.matmul(A * B_U[0], self.U_a)
+ A_a = torch.reshape(A_a, (-1, self.hidden_dim))
+ a = self.activation(A_a + self.b_a)
+
+ o = self.inner_activation(x_o + torch.matmul(h_tm1 * B_U[0], self.U_o))
+
+ h = o * a
+ p = torch.matmul(h, self.W_p) + self.b_p
+
+ self.states = [p, h, S_re, S_im, time, None, None, None]
+ self.states = []
+ return self.fc_out(p).squeeze()
+
+ def init_states(self, x):
+ reducer_f = torch.zeros((self.hidden_dim, self.freq_dim)).to(self.device)
+ reducer_p = torch.zeros((self.hidden_dim, self.output_dim)).to(self.device)
+
+ init_state_h = torch.zeros(self.hidden_dim).to(self.device)
+ init_state_p = torch.matmul(init_state_h, reducer_p)
+
+ init_state = torch.zeros_like(init_state_h).to(self.device)
+ init_freq = torch.matmul(init_state_h, reducer_f)
+
+ init_state = torch.reshape(init_state, (-1, self.hidden_dim, 1))
+ init_freq = torch.reshape(init_freq, (-1, 1, self.freq_dim))
+
+ init_state_S_re = init_state * init_freq
+ init_state_S_im = init_state * init_freq
+
+ init_state_time = torch.tensor(0).to(self.device)
+
+ self.states = [
+ init_state_p,
+ init_state_h,
+ init_state_S_re,
+ init_state_S_im,
+ init_state_time,
+ None,
+ None,
+ None,
+ ]
+
+ def get_constants(self, x):
+ constants = []
+ constants.append([torch.tensor(1.0).to(self.device) for _ in range(6)])
+ constants.append([torch.tensor(1.0).to(self.device) for _ in range(7)])
+ array = np.array([float(ii) / self.freq_dim for ii in range(self.freq_dim)])
+ constants.append(torch.tensor(array).to(self.device))
+
+ self.states[5:] = constants
+
+
+class SFM(Model):
+ """SFM Model
+
+ Parameters
+ ----------
+ input_dim : int
+ input dimension
+ output_dim : int
+ output dimension
+ lr : float
+ learning rate
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ output_dim=1,
+ freq_dim=10,
+ dropout_W=0.0,
+ dropout_U=0.0,
+ n_epochs=200,
+ lr=0.001,
+ metric="",
+ batch_size=2000,
+ early_stop=20,
+ eval_steps=5,
+ loss="mse",
+ optimizer="gd",
+ GPU="0",
+ seed=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("SFM")
+ self.logger.info("SFM pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.output_dim = output_dim
+ self.freq_dim = freq_dim
+ self.dropout_W = dropout_W
+ self.dropout_U = dropout_U
+ self.n_epochs = n_epochs
+ self.lr = lr
+ self.metric = metric
+ self.batch_size = batch_size
+ self.early_stop = early_stop
+ self.eval_steps = eval_steps
+ self.optimizer = optimizer.lower()
+ self.loss = loss
+ self.device = "cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu"
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+
+ self.logger.info(
+ "SFM parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\noutput_size : {}"
+ "\nfrequency_dimension : {}"
+ "\ndropout_W: {}"
+ "\ndropout_U: {}"
+ "\nn_epochs : {}"
+ "\nlr : {}"
+ "\nmetric : {}"
+ "\nbatch_size : {}"
+ "\nearly_stop : {}"
+ "\neval_steps : {}"
+ "\noptimizer : {}"
+ "\nloss_type : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ output_dim,
+ freq_dim,
+ dropout_W,
+ dropout_U,
+ n_epochs,
+ lr,
+ metric,
+ batch_size,
+ early_stop,
+ eval_steps,
+ optimizer.lower(),
+ loss,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ self.sfm_model = SFM_Model(
+ d_feat=self.d_feat,
+ output_dim=self.output_dim,
+ hidden_size=self.hidden_size,
+ freq_dim=self.freq_dim,
+ dropout_W=self.dropout_W,
+ dropout_U=self.dropout_U,
+ device=self.device,
+ )
+ if optimizer.lower() == "adam":
+ self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)
+ elif optimizer.lower() == "gd":
+ self.train_optimizer = optim.SGD(self.sfm_model.parameters(), lr=self.lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+ self._fitted = False
+ self.sfm_model.to(self.device)
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.sfm_model.eval()
+
+ scores = []
+ losses = []
+
+ indices = np.arange(len(x_values))
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
+ label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+ pred = self.sfm_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ self.sfm_model.train()
+
+ indices = np.arange(len(x_train_values))
+ np.random.shuffle(indices)
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+ label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+ pred = self.sfm_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.sfm_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ stop_steps = 0
+ train_loss = 0
+ best_score = -np.inf
+ best_epoch = 0
+ evals_result["train"] = []
+ evals_result["valid"] = []
+
+ # train
+ self.logger.info("training...")
+ self._fitted = True
+
+ for step in range(self.n_epochs):
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
+
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.sfm_model.state_dict())
+ else:
+ stop_steps += 1
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+ if self.device != "cpu":
+ torch.cuda.empty_cache()
+
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss":
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
+
+ def predict(self, dataset):
+ if not self._fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.sfm_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
+
+ for begin in range(sample_num)[:: self.batch_size]:
+ if sample_num - begin < self.batch_size:
+ end = sample_num
+ else:
+ end = begin + self.batch_size
+
+ x_batch = torch.from_numpy(x_values[begin:end]).float()
+
+ if self.device != "cpu":
+ x_batch = x_batch.to(self.device)
+
+ with torch.no_grad():
+ pred = self.sfm_model(x_batch).detach().cpu().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py
new file mode 100755
index 000000000..ba2e5789b
--- /dev/null
+++ b/qlib/contrib/model/xgboost.py
@@ -0,0 +1,64 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+import pandas as pd
+import xgboost as xgb
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class XGBModel(Model):
+ """XGBModel Model"""
+
+ def __init__(self, **kwargs):
+ self._params = {}
+ self._params.update(kwargs)
+ self.model = None
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ num_boost_round=1000,
+ early_stopping_rounds=50,
+ verbose_eval=20,
+ evals_result=dict(),
+ **kwargs
+ ):
+
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ # Lightgbm need 1D array as its label
+ if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
+ y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
+ else:
+ raise ValueError("XGBoost doesn't support multi-label training")
+
+ dtrain = xgb.DMatrix(x_train.values, label=y_train_1d)
+ dvalid = xgb.DMatrix(x_valid.values, label=y_valid_1d)
+ self.model = xgb.train(
+ self._params,
+ dtrain=dtrain,
+ num_boost_round=num_boost_round,
+ evals=[(dtrain, "train"), (dvalid, "valid")],
+ early_stopping_rounds=early_stopping_rounds,
+ verbose_eval=verbose_eval,
+ evals_result=evals_result,
+ **kwargs
+ )
+ evals_result["train"] = list(evals_result["train"].values())[0]
+ evals_result["valid"] = list(evals_result["valid"].values())[0]
+
+ def predict(self, dataset):
+ if self.model is None:
+ raise ValueError("model is not fitted yet!")
+ x_test = dataset.prepare("test", col_set="feature")
+ return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
diff --git a/qlib/contrib/online/__init__.py b/qlib/contrib/online/__init__.py
index e69de29bb..71389882e 100644
--- a/qlib/contrib/online/__init__.py
+++ b/qlib/contrib/online/__init__.py
@@ -0,0 +1,18 @@
+'''
+TODO:
+
+- Online needs that the model have such method
+ def get_data_with_date(self, date, **kwargs):
+ """
+ Will be called in online module
+ need to return the data that used to predict the label (score) of stocks at date.
+
+ :param
+ date: pd.Timestamp
+ predict date
+ :return:
+ data: the input data that used to predict the label (score) of stocks at predict date.
+ """
+ raise NotImplementedError("get_data_with_date for this model is not implemented.")
+
+'''
diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py
index 7e9c766e8..cf850b9da 100644
--- a/qlib/contrib/online/manager.py
+++ b/qlib/contrib/online/manager.py
@@ -11,7 +11,7 @@ from ..backtest.account import Account
from ..backtest.exchange import Exchange
from .user import User
from .utils import load_instance
-from .utils import save_instance, init_instance_by_config
+from ...utils import save_instance, init_instance_by_config
class UserManager:
diff --git a/qlib/contrib/online/utils.py b/qlib/contrib/online/utils.py
index cf08e4dbe..611af63e4 100644
--- a/qlib/contrib/online/utils.py
+++ b/qlib/contrib/online/utils.py
@@ -7,7 +7,7 @@ import yaml
import pandas as pd
from ...data import D
from ...log import get_module_logger
-from ...utils import get_module_by_module_path
+from ...utils import get_module_by_module_path, init_instance_by_config
from ...utils import get_next_trading_date
from ..backtest.exchange import Exchange
@@ -45,21 +45,6 @@ def save_instance(instance, file_path):
pickle.dump(instance, fr)
-def init_instance_by_config(config):
- """
- generate an instance with settings in config
- Parameter
- config : dict
- python dict indicate a init parameters to create an item
- :return
- An instance
- """
- module = get_module_by_module_path(config["module_path"])
- instance_class = getattr(module, config["class"])
- instance = instance_class(**config["args"])
- return instance
-
-
def create_user_folder(path):
path = pathlib.Path(path)
if path.exists():
diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py
index 1c69145db..1cb14d261 100644
--- a/qlib/contrib/report/analysis_model/analysis_model_performance.py
+++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py
@@ -252,7 +252,7 @@ def model_performance_graph(
"""Model performance
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score,
- label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1")
+ label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1").
.. code-block:: python
@@ -266,13 +266,13 @@ def model_performance_graph(
:param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing.
- :param N: group number, default 5
- :param reverse: if `True`, `pred['score'] *= -1`
- :param rank: if **True**, calculate rank ic
- :param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover']
- :param show_notebook: whether to display graphics in notebook, the default is `True`
- :param show_nature_day: whether to display the abscissa of non-trading day
- :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list
+ :param N: group number, default 5.
+ :param reverse: if `True`, `pred['score'] *= -1`.
+ :param rank: if **True**, calculate rank ic.
+ :param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'].
+ :param show_notebook: whether to display graphics in notebook, the default is `True`.
+ :param show_nature_day: whether to display the abscissa of non-trading day.
+ :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list.
"""
figure_list = []
for graph_name in graph_names:
diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py
index 941785e83..abb68ea60 100644
--- a/qlib/contrib/report/analysis_position/cumulative_return.py
+++ b/qlib/contrib/report/analysis_position/cumulative_return.py
@@ -218,10 +218,10 @@ def cumulative_return_graph(
Graph desc:
- - Axis X: Trading day
+ - Axis X: Trading day.
- Axis Y:
- - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`
- - Below axis Y: Daily weight sum
+ - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`.
+ - Below axis Y: Daily weight sum.
- In the **sell** graph, `y < 0` stands for profit; in other cases, `y > 0` stands for profit.
- In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + sell_weight`.
- In each graph, the **red line** in the histogram on the right represents the average.
diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py
index e2f7fe1cf..72a358adc 100644
--- a/qlib/contrib/report/analysis_position/rank_label.py
+++ b/qlib/contrib/report/analysis_position/rank_label.py
@@ -97,9 +97,9 @@ def rank_label_graph(
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
- :param position: position data; **qlib.contrib.backtest.backtest.backtest** result
+ :param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
:param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**.
- **The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`
+ **The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`.
.. code-block:: python
@@ -115,7 +115,7 @@ def rank_label_graph(
:param start_date: start date
:param end_date: end_date
- :param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures
+ :param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures.
:return:
"""
position = copy.deepcopy(position)
diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py
index e8bb5313f..438aab8b9 100644
--- a/qlib/contrib/report/analysis_position/report.py
+++ b/qlib/contrib/report/analysis_position/report.py
@@ -75,11 +75,12 @@ def _report_figure(df: pd.DataFrame) -> [list, tuple]:
max_start_date, max_end_date = _calculate_maximum(report_df)
ex_max_start_date, ex_max_end_date = _calculate_maximum(report_df, True)
+ index_name = report_df.index.name
_temp_df = report_df.reset_index()
_temp_df.loc[-1] = 0
_temp_df = _temp_df.shift(1)
- _temp_df.loc[0, "index"] = "T0"
- _temp_df.set_index("index", inplace=True)
+ _temp_df.loc[0, index_name] = "T0"
+ _temp_df.set_index(index_name, inplace=True)
_temp_df.iloc[0] = 0
report_df = _temp_df
@@ -99,13 +100,13 @@ def _report_figure(df: pd.DataFrame) -> [list, tuple]:
("cum_ex_return_wo_cost_mdd", dict(row=7, col=1, graph_kwargs=_temp_fill_args)),
]
- _subplot_layout = dict(
- xaxis=dict(showline=True, type="category", tickangle=45),
- yaxis=dict(zeroline=True, showline=True, showticklabels=True),
- )
- for i in range(2, 8):
+ _subplot_layout = dict()
+ for i in range(1, 8):
# yaxis
_subplot_layout.update({"yaxis{}".format(i): dict(zeroline=True, showline=True, showticklabels=True)})
+ _show_line = i == 7
+ _subplot_layout.update({"xaxis{}".format(i): dict(showline=_show_line, type="category", tickangle=45)})
+
_layout_style = dict(
height=1200,
title=" ",
@@ -185,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
qcr.report_graph(report_normal_df)
- :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**
+ :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
.. code-block:: python
@@ -199,8 +200,8 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
- :param show_notebook: whether to display graphics in notebook, the default is **True**
- :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
+ :param show_notebook: whether to display graphics in notebook, the default is **True**.
+ :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
"""
report_df = report_df.copy()
fig_list = _report_figure(report_df)
diff --git a/qlib/contrib/report/analysis_position/risk_analysis.py b/qlib/contrib/report/analysis_position/risk_analysis.py
index 89650c39e..051c78035 100644
--- a/qlib/contrib/report/analysis_position/risk_analysis.py
+++ b/qlib/contrib/report/analysis_position/risk_analysis.py
@@ -116,7 +116,11 @@ def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]:
if analysis_df is None:
return []
- _figure = SubplotsGraph(_get_all_risk_analysis(analysis_df), kind_map=dict(kind="BarGraph", kwargs={})).figure
+ _figure = SubplotsGraph(
+ _get_all_risk_analysis(analysis_df),
+ kind_map=dict(kind="BarGraph", kwargs={}),
+ subplots_kwargs={"rows": 4, "cols": 1},
+ ).figure
return (_figure,)
@@ -214,7 +218,7 @@ def risk_analysis_graph(
max_drawdown -0.088263
- :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**
+ :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**.
.. code-block:: python
@@ -228,7 +232,7 @@ def risk_analysis_graph(
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
- :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**
+ :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**.
.. code-block:: python
@@ -242,7 +246,7 @@ def risk_analysis_graph(
2017-01-10 0.000824 -0.001944 -0.001120
- :param show_notebook: Whether to display graphics in a notebook, default **True**
+ :param show_notebook: Whether to display graphics in a notebook, default **True**.
If True, show graph in notebook
If False, return graph figure
:return:
diff --git a/qlib/contrib/report/analysis_position/score_ic.py b/qlib/contrib/report/analysis_position/score_ic.py
index 9a2fc8560..a6a7a8b0e 100644
--- a/qlib/contrib/report/analysis_position/score_ic.py
+++ b/qlib/contrib/report/analysis_position/score_ic.py
@@ -36,7 +36,7 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
analysis_position.score_ic_graph(pred_label)
- :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**
+ :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**.
.. code-block:: python
@@ -49,8 +49,8 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
2017-12-15 -0.102778 -0.102778
- :param show_notebook: whether to display graphics in notebook, the default is **True**
- :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
+ :param show_notebook: whether to display graphics in notebook, the default is **True**.
+ :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
"""
_ic_df = _get_score_ic(pred_label)
# FIXME: support HIGH-FREQ
diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py
index 082eafa49..3fa688d36 100644
--- a/qlib/contrib/report/graph.py
+++ b/qlib/contrib/report/graph.py
@@ -11,7 +11,7 @@ import pandas as pd
import plotly.offline as py
import plotly.graph_objs as go
-from plotly.tools import make_subplots
+from plotly.subplots import make_subplots
from plotly.figure_factory import create_distplot
from ...utils import get_module_by_module_path
@@ -96,7 +96,19 @@ class BaseGraph(object):
"""
py.init_notebook_mode()
for _fig in figure_list:
- py.iplot(_fig)
+ # NOTE: displays figures: https://plotly.com/python/renderers/
+ # default: plotly_mimetype+notebook
+ # support renderers: import plotly.io as pio; print(pio.renderers)
+ renderer = None
+ try:
+ # in notebook
+ _ipykernel = str(type(get_ipython()))
+ if "google.colab" in _ipykernel:
+ renderer = "colab"
+ except NameError:
+ pass
+
+ _fig.show(renderer=renderer)
def _get_layout(self) -> go.Layout:
"""
@@ -125,7 +137,10 @@ class BaseGraph(object):
:return:
"""
- return go.Figure(data=self.data, layout=self._get_layout())
+ _figure = go.Figure(data=self.data, layout=self._get_layout())
+ # NOTE: using default 3.x theme
+ _figure["layout"].update(template=None)
+ return _figure
class ScatterGraph(BaseGraph):
@@ -357,13 +372,14 @@ class SubplotsGraph(object):
# _item.pop('yaxis', None)
for _g_obj in _graph_data:
- self._figure.append_trace(_g_obj, row=row, col=col)
+ self._figure.add_trace(_g_obj, row=row, col=col)
if self._sub_graph_layout is not None:
for k, v in self._sub_graph_layout.items():
self._figure["layout"][k].update(v)
- self._figure["layout"].update(self._layout)
+ # NOTE: using default 3.x theme
+ self._figure["layout"].update(self._layout, template=None)
@property
def figure(self):
diff --git a/qlib/contrib/strategy/strategy.py b/qlib/contrib/strategy/strategy.py
index 339af31e5..23e8b5185 100644
--- a/qlib/contrib/strategy/strategy.py
+++ b/qlib/contrib/strategy/strategy.py
@@ -11,6 +11,7 @@ from ...utils import get_pre_trading_date
from .order_generator import OrderGenWInteract
+# TODO: The base strategies will be moved out of contrib to core code
class BaseStrategy:
def __init__(self):
pass
@@ -24,33 +25,35 @@ class BaseStrategy:
return 0.95
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
- """Parameter
- score_series : pd.Seires
- stock_id , score
- current : Position()
- current state of position
- DO NOT directly change the state of current
- trade_exchange : Exchange()
- trade exchange
- pred_date : pd.Timestamp
- predict date
- trade_date : pd.Timestamp
- trade date
-
+ """
DO NOT directly change the state of current
+
+ Parameters
+ -----------
+ score_series : pd.Seires
+ stock_id , score.
+ current : Position()
+ current state of position.
+ DO NOT directly change the state of current.
+ trade_exchange : Exchange()
+ trade exchange.
+ pred_date : pd.Timestamp
+ predict date.
+ trade_date : pd.Timestamp
+ trade date.
"""
pass
def update(self, score_series, pred_date, trade_date):
"""User can use this method to update strategy state each trade date.
- Parameter
- ---------
+ Parameters
+ -----------
score_series : pd.Series
- stock_id , score
+ stock_id , score.
pred_date : pd.Timestamp
- oredict date
+ oredict date.
trade_date : pd.Timestamp
- trade date
+ trade date.
"""
pass
@@ -64,7 +67,7 @@ class BaseStrategy:
"""
This method only be used in 'online' module, it will generate the *args to initial the strategy.
:param
- mode : model used in 'online' module
+ mode : model used in 'online' module.
"""
return {}
@@ -79,7 +82,7 @@ class StrategyWrapper:
def __init__(self, inner_strategy):
"""__init__
- :param inner_strategy: set the inner strategy
+ :param inner_strategy: set the inner strategy.
"""
self.inner_strategy = inner_strategy
@@ -95,9 +98,10 @@ class AdjustTimer:
"""AdjustTimer
Responsible for timing of position adjusting
- This is designed as multiple inheritance mechanism due to
- - the is_adjust may need access to the internel state of a strategyw
- - it can be reguard as a enhancement to the existing strategy
+ This is designed as multiple inheritance mechanism due to:
+ - the is_adjust may need access to the internel state of a strategy.
+
+ - it can be reguard as a enhancement to the existing strategy.
"""
# adjust position in each trade date
@@ -136,26 +140,33 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
self.order_generator = order_generator_cls_or_obj
def generate_target_weight_position(self, score, current, trade_date):
- """Parameter:
- score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
- current : current position, use Position() class
+ """
+ Generate target position from score for this date and the current position.The cash is not considered in the position
+
+ Parameters
+ -----------
+ score : pd.Series
+ pred score for this trade date, index is stock_id, contain 'score' column.
+ current : Position()
+ current position.
trade_exchange : Exchange()
- trade_date : trade date
- generate target position from score for this date and the current position
- The cash is not considered in the position
+ trade_date : pd.Timestamp
+ trade date.
"""
raise NotImplementedError()
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
- """Parameter
+ """
+ Parameters
+ -----------
score_series : pd.Seires
- stock_id , score
+ stock_id , score.
current : Position()
- current of account
+ current of account.
trade_exchange : Exchange()
- exchange
+ exchange.
trade_date : pd.Timestamp
- date
+ date.
"""
# judge if to adjust
if not self.is_adjust(trade_date):
@@ -179,27 +190,49 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
- def __init__(self, topk, n_drop, method="bottom", risk_degree=0.95, thresh=1, hold_thresh=1, **kwargs):
- """Parameter
+ def __init__(
+ self,
+ topk,
+ n_drop,
+ method_sell="bottom",
+ method_buy="top",
+ risk_degree=0.95,
+ thresh=1,
+ hold_thresh=1,
+ only_tradable=False,
+ **kwargs,
+ ):
+ """
+ Parameters
+ -----------
topk : int
- The number of stocks in the portfolio
+ the number of stocks in the portfolio.
n_drop : int
- number of stocks to be replaced in each trading date
- method : str
- dropout method, random/bottom
+ number of stocks to be replaced in each trading date.
+ method_sell : str
+ dropout method_sell, random/bottom.
+ method_buy : str
+ dropout method_buy, random/top.
risk_degree : float
- position percentage of total value
+ position percentage of total value.
thresh : int
- minimun holding days since last buy singal of the stock
+ minimun holding days since last buy singal of the stock.
hold_thresh : int
minimum holding days
- before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh
+ before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh.
+ only_tradable : bool
+ will the strategy only consider the tradable stock when buying and selling.
+ if only_tradable:
+ strategy will make buy sell decision without checking the tradable state of the stock.
+ else:
+ strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__()
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
self.topk = topk
self.n_drop = n_drop
- self.method = method
+ self.method_sell = method_sell
+ self.method_buy = method_buy
self.risk_degree = risk_degree
self.thresh = thresh
# self.stock_count['code'] will be the days the stock has been hold
@@ -207,50 +240,114 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
self.stock_count = {}
self.hold_thresh = hold_thresh
+ self.only_tradable = only_tradable
def get_risk_degree(self, date):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
- Dynamically risk_degree will result in Market timing
+ Dynamically risk_degree will result in Market timing.
"""
# It will use 95% amoutn of your total value by default
return self.risk_degree
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
- """Gnererate order list according to score_series at trade_date.
- will not change current.
- Parameter
- score_series : pd.Seires
- stock_id , score
- current : Position()
- current of account
- trade_exchange : Exchange()
- exchange
- pred_date : pd.Timestamp
- predict date
- trade_date : pd.Timestamp
- trade date
+ """
+ Gnererate order list according to score_series at trade_date, will not change current.
+
+ Parameters
+ -----------
+ score_series : pd.Series
+ stock_id , score.
+ current : Position()
+ current of account.
+ trade_exchange : Exchange()
+ exchange.
+ pred_date : pd.Timestamp
+ predict date.
+ trade_date : pd.Timestamp
+ trade date.
"""
if not self.is_adjust(trade_date):
return []
+
+ if self.only_tradable:
+ # If The strategy only consider tradable stock when make decision
+ # It needs following actions to filter stocks
+ def get_first_n(l, n, reverse=False):
+ cur_n = 0
+ res = []
+ for si in reversed(l) if reverse else l:
+ if trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date):
+ res.append(si)
+ cur_n += 1
+ if cur_n >= n:
+ break
+ return res[::-1] if reverse else res
+
+ def get_last_n(l, n):
+ return get_first_n(l, n, reverse=True)
+
+ def filter_stock(l):
+ return [si for si in l if trade_exchange.is_stock_tradable(stock_id=si, trade_date=trade_date)]
+
+ else:
+ # Otherwise, the stock will make decision with out the stock tradable info
+ def get_first_n(l, n):
+ return list(l)[:n]
+
+ def get_last_n(l, n):
+ return list(l)[-n:]
+
+ def filter_stock(l):
+ return l
+
current_temp = copy.deepcopy(current)
# generate order list for this adjust date
sell_order_list = []
buy_order_list = []
# load score
+ cash = current_temp.get_cash()
current_stock_list = current_temp.get_stock_list()
+ # last position (sorted by score)
last = score_series.reindex(current_stock_list).sort_values(ascending=False).index
- today = (
- score_series[~score_series.index.isin(last)]
- .sort_values(ascending=False)
- .index[: self.n_drop + self.topk - len(last)]
- )
- comb = score_series.reindex(last.union(today)).sort_values(ascending=False).index
- if self.method == "bottom":
- sell = last[last.isin(comb[-self.n_drop :])]
- elif self.method == "random":
- sell = pd.Index(np.random.choice(last, self.n_drop) if len(last) else [])
+ # The new stocks today want to buy **at most**
+ if self.method_buy == "top":
+ today = get_first_n(
+ score_series[~score_series.index.isin(last)].sort_values(ascending=False).index,
+ self.n_drop + self.topk - len(last),
+ )
+ elif self.method_buy == "random":
+ topk_candi = get_first_n(score_series.sort_values(ascending=False).index, self.topk)
+ candi = list(filter(lambda x: x not in last, topk_candi))
+ n = self.n_drop + self.topk - len(last)
+ try:
+ today = np.random.choice(candi, n, replace=False)
+ except ValueError:
+ today = candi
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ # combine(new stocks + last stocks), we will drop stocks from this list
+ # In case of dropping higher score stock and buying lower score stock.
+ comb = score_series.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index
+
+ # Get the stock list we really want to sell (After filtering the case that we sell high and buy low)
+ if self.method_sell == "bottom":
+ sell = last[last.isin(get_last_n(comb, self.n_drop))]
+ elif self.method_sell == "random":
+ candi = filter_stock(last)
+ try:
+ sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else [])
+ except ValueError: # No enough candidates
+ sell = candi
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ # Get the stock list we really want to buy
buy = today[: len(sell) + self.topk - len(last)]
+
+ # buy singal: if a stock falls into topk, it appear in the buy_sinal
+ buy_signal = score_series.sort_values(ascending=False).iloc[: self.topk].index
+
for code in current_stock_list:
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
continue
@@ -274,12 +371,14 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
if trade_exchange.check_order(sell_order):
sell_order_list.append(sell_order)
trade_val, trade_cost, trade_price = trade_exchange.deal_order(sell_order, position=current_temp)
+ # update cash
+ cash += trade_val - trade_cost
# sold
del self.stock_count[code]
else:
# no buy signal, but the stock is kept
self.stock_count[code] += 1
- elif code in buy:
+ elif code in buy_signal:
# NOTE: This is different from the original version
# get new buy signal
# Only the stock fall in to topk will produce buy signal
@@ -289,7 +388,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
# buy new stock
# note the current has been changed
current_stock_list = current_temp.get_stock_list()
- value = current_temp.get_cash() * self.risk_degree / len(buy) if len(buy) > 0 else 0
+ value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it
# as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
diff --git a/qlib/data/_libs/expanding.pyx b/qlib/data/_libs/expanding.pyx
index 76b824c94..47bc49610 100644
--- a/qlib/data/_libs/expanding.pyx
+++ b/qlib/data/_libs/expanding.pyx
@@ -14,7 +14,7 @@ cdef class Expanding(object):
cdef int na_count
def __init__(self):
self.na_count = 0
-
+
cdef double update(self, double val):
pass
@@ -25,7 +25,7 @@ cdef class Mean(Expanding):
def __init__(self):
super(Mean, self).__init__()
self.vsum = 0
-
+
cdef double update(self, double val):
self.barv.push_back(val)
if isnan(val):
@@ -62,7 +62,7 @@ cdef class Slope(Expanding):
return (N*self.xy_sum - self.x_sum*self.y_sum) / \
(N*self.x2_sum - self.x_sum*self.x_sum)
-
+
cdef class Resi(Expanding):
"""1-D array expanding residuals"""
cdef double x_sum
@@ -94,7 +94,7 @@ cdef class Resi(Expanding):
interp = y_mean - slope*x_mean
return val - (slope*size + interp)
-
+
cdef class Rsquare(Expanding):
"""1-D array expanding rsquare"""
cdef double x_sum
@@ -117,7 +117,7 @@ cdef class Rsquare(Expanding):
self.na_count += 1
else:
self.x_sum += size
- self.x2_sum += size
+ self.x2_sum += size * size
self.y_sum += val
self.y2_sum += val * val
self.xy_sum += size * val
@@ -126,7 +126,7 @@ cdef class Rsquare(Expanding):
sqrt((N*self.x2_sum - self.x_sum*self.x_sum) * (N*self.y2_sum - self.y_sum*self.y_sum))
return rvalue * rvalue
-
+
cdef np.ndarray[double, ndim=1] expanding(Expanding r, np.ndarray a):
cdef int i
cdef int N = len(a)
diff --git a/qlib/data/base.py b/qlib/data/base.py
index c357700c0..92fc57ffe 100644
--- a/qlib/data/base.py
+++ b/qlib/data/base.py
@@ -6,12 +6,10 @@ from __future__ import division
from __future__ import print_function
import abc
-import six
import pandas as pd
-@six.add_metaclass(abc.ABCMeta)
-class Expression(object):
+class Expression(abc.ABC):
"""Expression base class"""
def __str__(self):
@@ -131,13 +129,13 @@ class Expression(object):
Parameters
----------
instrument : str
- instrument code
+ instrument code.
start_index : str
- feature start index [in calendar]
+ feature start index [in calendar].
end_index : str
- feature end index [in calendar]
+ feature end index [in calendar].
freq : str
- feature frequency
+ feature frequency.
Returns
----------
@@ -218,7 +216,6 @@ class Feature(Expression):
return 0, 0
-@six.add_metaclass(abc.ABCMeta)
class ExpressionOps(Expression):
"""Operator Expression
diff --git a/qlib/data/cache.py b/qlib/data/cache.py
index 3cfb8dae9..3fab2b527 100644
--- a/qlib/data/cache.py
+++ b/qlib/data/cache.py
@@ -76,8 +76,8 @@ class MemCache(object):
Parameters
----------
- mem_cache_size_limit: cache max size
- limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof)
+ mem_cache_size_limit: cache max size.
+ limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof).
"""
if limit_type not in ["length", "sizeof"]:
raise ValueError(f"limit_type must be length or sizeof, your limit_type is {limit_type}")
@@ -118,9 +118,9 @@ class MemCacheExpire:
def set_cache(mem_cache, key, value):
"""set cache
- :param mem_cache: MemCache attribute('c'/'i'/'f')
- :param key: cache key
- :param value: cache value
+ :param mem_cache: MemCache attribute('c'/'i'/'f').
+ :param key: cache key.
+ :param value: cache value.
"""
mem_cache[key] = value, time.time()
@@ -128,9 +128,9 @@ class MemCacheExpire:
def get_cache(mem_cache, key):
"""get mem cache
- :param mem_cache: MemCache attribute('c'/'i'/'f')
- :param key: cache key
- :return: cache value; if cache not exist, return None
+ :param mem_cache: MemCache attribute('c'/'i'/'f').
+ :param key: cache key.
+ :return: cache value; if cache not exist, return None.
"""
value = None
expire = False
@@ -180,6 +180,7 @@ class CacheUtils(object):
> select {C.redis_task_db}
> del "lock:{repr(lock_name)[1:-1]}-wlock"
> quit
+ If the issue is not resolved, use "keys *" to find if multiple keys exist. If so, try using "flushall" to clear all the keys.
"""
)
@@ -274,12 +275,12 @@ class ExpressionCache(BaseProviderCache):
Parameters
----------
cache_uri : str
- the complete uri of expression cache file (include dir path)
+ the complete uri of expression cache file (include dir path).
Returns
-------
int
- 0(successful update)/ 1(no need to update)/ 2(update failure)
+ 0(successful update)/ 1(no need to update)/ 2(update failure).
"""
raise NotImplementedError("Implement this method if you want to make expression cache up to date")
@@ -347,7 +348,7 @@ class DatasetCache(BaseProviderCache):
Parameters
----------
cache_uri : str
- the complete uri of dataset cache file (include dir path)
+ the complete uri of dataset cache file (include dir path).
Returns
-------
@@ -360,9 +361,9 @@ class DatasetCache(BaseProviderCache):
def cache_to_origin_data(data, fields):
"""cache data to origin data
- :param data: pd.DataFrame, cache data
- :param fields: feature fields
- :return: pd.DataFrame
+ :param data: pd.DataFrame, cache data.
+ :param fields: feature fields.
+ :return: pd.DataFrame.
"""
not_space_fields = remove_fields_space(fields)
data = data.loc[:, not_space_fields]
@@ -582,7 +583,7 @@ class DiskDatasetCache(DatasetCache):
:param cache_path:
:param start_time:
:param end_time:
- :param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent
+ :param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent.
:return:
"""
@@ -747,7 +748,8 @@ class DiskDatasetCache(DatasetCache):
The format the cache contains 3 parts(followed by typical filename).
- - index : cache/d41366901e25de3ec47297f12e2ba11d.index
+ - index : cache/d41366901e25de3ec47297f12e2ba11d.index
+
- The content of the file may be in following format(pandas.Series)
.. code-block:: python
@@ -764,15 +766,17 @@ class DiskDatasetCache(DatasetCache):
- It indicates the `end_index` of the data for `timestamp`
- meta data: cache/d41366901e25de3ec47297f12e2ba11d.meta
+
- data : cache/d41366901e25de3ec47297f12e2ba11d
+
- This is a hdf file sorted by datetime
- :param cache_path: The path to store the cache
- :param instruments: The instruments to store the cache
- :param fields: The fields to store the cache
- :param freq: The freq to store the cache
+ :param cache_path: The path to store the cache.
+ :param instruments: The instruments to store the cache.
+ :param fields: The fields to store the cache.
+ :param freq: The freq to store the cache.
- :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function
+ :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
"""
# get calendar
from .data import Cal
diff --git a/qlib/data/client.py b/qlib/data/client.py
index 2e83726d1..65a830f20 100644
--- a/qlib/data/client.py
+++ b/qlib/data/client.py
@@ -7,7 +7,7 @@ from __future__ import print_function
import socketio
-from .. import __version__
+import qlib
from ..log import get_module_logger
import pickle
@@ -51,15 +51,15 @@ class Client(object):
Parameters
----------
request_type : str
- type of proposed request, 'calendar'/'instrument'/'feature'
+ type of proposed request, 'calendar'/'instrument'/'feature'.
request_content : dict
- records the information of the request
+ records the information of the request.
msg_proc_func : func
- the function to process the message when receiving response, should have arg `*args`
+ the function to process the message when receiving response, should have arg `*args`.
msg_queue: Queue
- The queue to pass the messsage after callback
+ The queue to pass the messsage after callback.
"""
- head_info = {"version": __version__}
+ head_info = {"version": qlib.__version__}
def request_callback(*args):
"""callback_wrapper
diff --git a/qlib/data/data.py b/qlib/data/data.py
index b630834ef..a4c3d63f2 100644
--- a/qlib/data/data.py
+++ b/qlib/data/data.py
@@ -7,7 +7,6 @@ from __future__ import print_function
import os
import abc
-import six
import time
import queue
import bisect
@@ -16,6 +15,7 @@ import importlib
import traceback
import numpy as np
import pandas as pd
+from pathlib import Path
from multiprocessing import Pool
from .cache import H
@@ -25,10 +25,10 @@ from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
+from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
-@six.add_metaclass(abc.ABCMeta)
-class CalendarProvider(object):
+class CalendarProvider(abc.ABC):
"""Calendar provider base class
Provide calendar data.
@@ -41,13 +41,13 @@ class CalendarProvider(object):
Parameters
----------
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
future : bool
- whether including future trading day
+ whether including future trading day.
Returns
----------
@@ -62,24 +62,24 @@ class CalendarProvider(object):
Parameters
----------
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
future : bool
- whether including future trading day
+ whether including future trading day.
Returns
-------
pd.Timestamp
- the real start time
+ the real start time.
pd.Timestamp
- the real end time
+ the real end time.
int
- the index of start time
+ the index of start time.
int
- the index of end time
+ the index of end time.
"""
start_time = pd.Timestamp(start_time)
end_time = pd.Timestamp(end_time)
@@ -103,16 +103,16 @@ class CalendarProvider(object):
Parameters
----------
freq : str
- frequency of read calendar file
+ frequency of read calendar file.
future : bool
- whether including future trading day
+ whether including future trading day.
Returns
-------
list
- list of timestamps
+ list of timestamps.
dict
- dict composed by timestamp as key and index as value for fast search
+ dict composed by timestamp as key and index as value for fast search.
"""
flag = f"{freq}_future_{future}"
if flag in H["c"]:
@@ -128,8 +128,7 @@ class CalendarProvider(object):
return hash_args(start_time, end_time, freq, future)
-@six.add_metaclass(abc.ABCMeta)
-class InstrumentProvider(object):
+class InstrumentProvider(abc.ABC):
"""Instrument provider base class
Provide instrument data.
@@ -142,27 +141,30 @@ class InstrumentProvider(object):
Parameters
----------
market : str
- market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500
+ market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.
filter_pipe : list
- the list of dynamic filters
+ the list of dynamic filters.
Returns
----------
dict
- dict of stockpool config
+ dict of stockpool config.
{`market`=>base market name, `filter_pipe`=>list of filters}
example :
- {'market': 'csi500',
- 'filter_pipe': [{'filter_type': 'ExpressionDFilter',
- 'rule_expression': '$open<40',
- 'filter_start_time': None,
- 'filter_end_time': None,
- 'keep': False},
- {'filter_type': 'NameDFilter',
- 'name_rule_re': 'SH[0-9]{4}55',
- 'filter_start_time': None,
- 'filter_end_time': None}]}
+
+ .. code-block::
+
+ {'market': 'csi500',
+ 'filter_pipe': [{'filter_type': 'ExpressionDFilter',
+ 'rule_expression': '$open<40',
+ 'filter_start_time': None,
+ 'filter_end_time': None,
+ 'keep': False},
+ {'filter_type': 'NameDFilter',
+ 'name_rule_re': 'SH[0-9]{4}55',
+ 'filter_start_time': None,
+ 'filter_end_time': None}]}
"""
if filter_pipe is None:
filter_pipe = []
@@ -180,13 +182,13 @@ class InstrumentProvider(object):
Parameters
----------
instruments : dict
- stockpool config
+ stockpool config.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
as_list : bool
- return instruments as list or dict
+ return instruments as list or dict.
Returns
-------
@@ -213,9 +215,22 @@ class InstrumentProvider(object):
return cls.LIST
raise ValueError(f"Unknown instrument type {inst}")
+ def convert_instruments(self, instrument):
+ _instruments_map = getattr(self, "_instruments_map", None)
+ if _instruments_map is None:
+ _df_list = []
+ # FIXME: each process will read these files
+ for _path in Path(C.get_data_path()).joinpath("instruments").glob("*.txt"):
+ _df = pd.read_csv(_path, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
+ _df_list.append(_df.iloc[:, [0, -1]])
+ df = pd.concat(_df_list, sort=False).sort_values("save_inst")
+ df = df.drop_duplicates(subset=["save_inst"], keep="first").fillna(axis=1, method="ffill")
+ _instruments_map = df.set_index("inst").iloc[:, 0].to_dict()
+ setattr(self, "_instruments_map", _instruments_map)
+ return _instruments_map.get(instrument, instrument)
-@six.add_metaclass(abc.ABCMeta)
-class FeatureProvider(object):
+
+class FeatureProvider(abc.ABC):
"""Feature provider class
Provide feature data.
@@ -228,15 +243,15 @@ class FeatureProvider(object):
Parameters
----------
instrument : str
- a certain instrument
+ a certain instrument.
field : str
- a certain field of feature
+ a certain field of feature.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
Returns
-------
@@ -246,8 +261,7 @@ class FeatureProvider(object):
raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method")
-@six.add_metaclass(abc.ABCMeta)
-class ExpressionProvider(object):
+class ExpressionProvider(abc.ABC):
"""Expression provider class
Provide Expression data.
@@ -280,15 +294,15 @@ class ExpressionProvider(object):
Parameters
----------
instrument : str
- a certain instrument
+ a certain instrument.
field : str
- a certain field of feature
+ a certain field of feature.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
Returns
-------
@@ -298,8 +312,7 @@ class ExpressionProvider(object):
raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method")
-@six.add_metaclass(abc.ABCMeta)
-class DatasetProvider(object):
+class DatasetProvider(abc.ABC):
"""Dataset provider class
Provide Dataset data.
@@ -312,20 +325,20 @@ class DatasetProvider(object):
Parameters
----------
instruments : list or dict
- list/dict of instruments or dict of stockpool config
+ list/dict of instruments or dict of stockpool config.
fields : list
- list of feature instances
+ list of feature instances.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency
+ time frequency.
Returns
----------
pd.DataFrame
- a pandas dataframe with index
+ a pandas dataframe with index.
"""
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")
@@ -344,17 +357,17 @@ class DatasetProvider(object):
Parameters
----------
instruments : list or dict
- list/dict of instruments or dict of stockpool config
+ list/dict of instruments or dict of stockpool config.
fields : list
- list of feature instances
+ list of feature instances.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency
+ time frequency.
disk_cache : int
- whether to skip(0)/use(1)/replace(2) disk_cache
+ whether to skip(0)/use(1)/replace(2) disk_cache.
"""
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
@@ -513,7 +526,7 @@ class LocalCalendarProvider(CalendarProvider):
Parameters
----------
freq : str
- frequency of read calendar file
+ frequency of read calendar file.
Returns
----------
@@ -575,19 +588,11 @@ class LocalInstrumentProvider(InstrumentProvider):
if not os.path.exists(fname):
raise ValueError("instruments not exists for market " + market)
_instruments = dict()
- with open(fname) as f:
- for line in f:
- inst_time = line.strip().split()
- inst = inst_time[0]
- if len(inst_time) == 3:
- # `day`
- begin = inst_time[1]
- end = inst_time[2]
- elif len(inst_time) == 5:
- # `1min`
- begin = inst_time[1] + " " + inst_time[2]
- end = inst_time[3] + " " + inst_time[4]
- _instruments.setdefault(inst, []).append((pd.Timestamp(begin), pd.Timestamp(end)))
+ df = pd.read_csv(fname, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
+ df["start_datetime"] = pd.to_datetime(df["start_datetime"])
+ df["end_datetime"] = pd.to_datetime(df["end_datetime"])
+ for row in df.itertuples(index=False):
+ _instruments.setdefault(row[0], []).append((row[1], row[2]))
return _instruments
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
@@ -642,10 +647,11 @@ class LocalFeatureProvider(FeatureProvider):
def feature(self, instrument, field, start_index, end_index, freq):
# validate
field = str(field).lower()[1:]
+ instrument = Inst.convert_instruments(instrument)
uri_data = self._uri_data.format(instrument.lower(), field, freq)
if not os.path.exists(uri_data):
get_module_logger("data").warning("WARN: data not found for %s.%s" % (instrument, field))
- return pd.Series()
+ return pd.Series(dtype=np.float32)
# raise ValueError('uri_data not found: ' + uri_data)
# load
series = read_bin(uri_data, start_index, end_index)
@@ -669,9 +675,10 @@ class LocalExpressionProvider(ExpressionProvider):
lft_etd, rght_etd = expression.get_extended_window_size()
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
# Ensure that each column type is consistent
- # FIXME: The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
+ # FIXME:
+ # 1) The stock data is currently float. If there is other types of data, this part needs to be re-implemented.
+ # 2) The the precision should be configurable
try:
- # TODO: the default storage and calculation type should be configurable
series = series.astype(np.float32)
except ValueError:
pass
@@ -952,6 +959,8 @@ class BaseProvider:
disk_cache=None,
):
"""
+ Parameters:
+ -----------
disk_cache : int
whether to skip(0)/use(1)/replace(2) disk_cache
@@ -1026,44 +1035,6 @@ class ClientProvider(BaseProvider):
DatasetD.set_conn(self.client)
-class Wrapper(object):
- """Data Provider Wrapper"""
-
- def __init__(self):
- self._provider = None
-
- def register(self, provider):
- self._provider = provider
-
- def __getattr__(self, key):
- if self._provider is None:
- raise AttributeError("Please run qlib.init() first using qlib")
- return getattr(self._provider, key)
-
-
-def get_cls_from_name(cls_name):
- return getattr(importlib.import_module(".data", package="qlib"), cls_name)
-
-
-def get_provider_obj(config, **params):
- if isinstance(config, dict):
- params.update(config["kwargs"])
- config = config["class"]
- return get_cls_from_name(config)(**params)
-
-
-def register_wrapper(wrapper, cls_or_obj):
- """register_wrapper
-
- :param wrapper: A wrapper of all kinds of providers
- :param cls_or_obj: A class or class name or object instance in data/data.py
- """
- if isinstance(cls_or_obj, str):
- cls_or_obj = get_cls_from_name(cls_or_obj)
- obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
- wrapper.register(obj)
-
-
Cal = Wrapper()
Inst = Wrapper()
FeatureD = Wrapper()
@@ -1075,34 +1046,35 @@ D = Wrapper()
def register_all_wrappers():
"""register_all_wrappers"""
logger = get_module_logger("data")
+ module = get_module_by_module_path("qlib.data")
- _calendar_provider = get_provider_obj(C.calendar_provider)
+ _calendar_provider = init_instance_by_config(C.calendar_provider, module)
if getattr(C, "calendar_cache", None) is not None:
- _calendar_provider = get_provider_obj(C.calendar_cache, provider=_calendar_provider)
- register_wrapper(Cal, _calendar_provider)
+ _calendar_provider = init_instance_by_config(C.calendar_cache, module, provide=_calendar_provider)
+ register_wrapper(Cal, _calendar_provider, "qlib.data")
logger.debug(f"registering Cal {C.calendar_provider}-{C.calenar_cache}")
- register_wrapper(Inst, C.instrument_provider)
+ register_wrapper(Inst, C.instrument_provider, "qlib.data")
logger.debug(f"registering Inst {C.instrument_provider}")
if getattr(C, "feature_provider", None) is not None:
- feature_provider = get_provider_obj(C.feature_provider)
- register_wrapper(FeatureD, feature_provider)
+ feature_provider = init_instance_by_config(C.feature_provider, module)
+ register_wrapper(FeatureD, feature_provider, "qlib.data")
logger.debug(f"registering FeatureD {C.feature_provider}")
if getattr(C, "expression_provider", None) is not None:
# This provider is unnecessary in client provider
- _eprovider = get_provider_obj(C.expression_provider)
+ _eprovider = init_instance_by_config(C.expression_provider, module)
if getattr(C, "expression_cache", None) is not None:
- _eprovider = get_provider_obj(C.expression_cache, provider=_eprovider)
- register_wrapper(ExpressionD, _eprovider)
+ _eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider)
+ register_wrapper(ExpressionD, _eprovider, "qlib.data")
logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}")
- _dprovider = get_provider_obj(C.dataset_provider)
+ _dprovider = init_instance_by_config(C.dataset_provider, module)
if getattr(C, "dataset_cache", None) is not None:
- _dprovider = get_provider_obj(C.dataset_cache, provider=_dprovider)
- register_wrapper(DatasetD, _dprovider)
+ _dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider)
+ register_wrapper(DatasetD, _dprovider, "qlib.data")
logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}")
- register_wrapper(D, C.provider)
+ register_wrapper(D, C.provider, "qlib.data")
logger.debug(f"registering D {C.provider}")
diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py
new file mode 100644
index 000000000..e7d296d73
--- /dev/null
+++ b/qlib/data/dataset/__init__.py
@@ -0,0 +1,165 @@
+from ...utils.serial import Serializable
+from typing import Union, List, Tuple
+from ...utils import init_instance_by_config
+from ...log import get_module_logger
+from .handler import DataHandler, DataHandlerLP
+from inspect import getfullargspec
+import pandas as pd
+
+
+class Dataset(Serializable):
+ """
+ Preparing data for model training and inferencing.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """
+ init is designed to finish following steps:
+
+ - setup data
+ - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
+
+ - initialize the state of the dataset(info to prepare the data)
+ - The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
+
+ The data could specify the info to caculate the essential data for preparation
+ """
+ self.setup_data(*args, **kwargs)
+ super().__init__()
+
+ def setup_data(self, *args, **kwargs):
+ """
+ Setup the data.
+
+ We split the setup_data function for following situation:
+
+ - User have a Dataset object with learned status on disk.
+
+ - User load the Dataset object from the disk(Note the init function is skiped).
+
+ - User call `setup_data` to load new data.
+
+ - User prepare data for model based on previous status.
+ """
+ pass
+
+ def prepare(self, *args, **kwargs) -> object:
+ """
+ The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
+ The parameters should specify the scope for the prepared data
+ The method should:
+ - process the data
+
+ - return the processed data
+
+ Returns
+ -------
+ object:
+ return the object
+ """
+ pass
+
+
+class DatasetH(Dataset):
+ """
+ Dataset with Data(H)andler
+
+ User should try to put the data preprocessing functions into handler.
+ Only following data processing functions should be placed in Dataset:
+
+ - The processing is related to specific model.
+
+ - The processing is related to data split.
+ """
+
+ def __init__(self, handler: Union[dict, DataHandler], segments: list):
+ """
+ Parameters
+ ----------
+ handler : Union[dict, DataHandler]
+ handler will be passed into setup_data.
+ segments : list
+ handler will be passed into setup_data.
+ """
+ super().__init__(handler, segments)
+
+ def setup_data(self, handler: Union[dict, DataHandler], segments: list):
+ """
+ Setup the underlying data.
+
+ Parameters
+ ----------
+ handler : Union[dict, DataHandler]
+ handler could be:
+
+ - insntance of `DataHandler`
+
+ - config of `DataHandler`. Please refer to `DataHandler`
+
+ segments : list
+ Describe the options to segment the data.
+ Here are some examples:
+
+ .. code-block::
+
+ 1) 'segments': {
+ 'train': ("2008-01-01", "2014-12-31"),
+ 'valid': ("2017-01-01", "2020-08-01",),
+ 'test': ("2015-01-01", "2016-12-31",),
+ }
+ 2) 'segments': {
+ 'insample': ("2008-01-01", "2014-12-31"),
+ 'outsample': ("2017-01-01", "2020-08-01",),
+ }
+ """
+ self._handler = init_instance_by_config(handler, accept_types=DataHandler)
+ self._segments = segments.copy()
+
+ def prepare(
+ self,
+ segments: Union[List[str], Tuple[str], str, slice],
+ col_set=DataHandler.CS_ALL,
+ data_key=DataHandlerLP.DK_I,
+ **kwargs,
+ ) -> Union[List[pd.DataFrame], pd.DataFrame]:
+ """
+ Prepare the data for learning and inference.
+
+ Parameters
+ ----------
+ segments : Union[List[str], Tuple[str], str, slice]
+ Describe the scope of the data to be prepared
+ Here are some examples:
+
+ - 'train'
+
+ - ['train', 'valid']
+
+ col_set : str
+ The col_set will be passed to self._handler when fetching data.
+ data_key : str
+ The data to fetch: DK_*
+ Default is DK_I, which indicate fetching data for **inference**.
+
+ Returns
+ -------
+ Union[List[pd.DataFrame], pd.DataFrame]:
+
+ Raises
+ ------
+ NotImplementedError:
+ """
+ logger = get_module_logger("DatasetH")
+ fetch_kwargs = {"col_set": col_set}
+ fetch_kwargs.update(kwargs)
+ if "data_key" in getfullargspec(self._handler.fetch).args:
+ fetch_kwargs["data_key"] = data_key
+ else:
+ logger.info(f"data_key[{data_key}] is ignored.")
+
+ if isinstance(segments, (list, tuple)):
+ return [self._handler.fetch(slice(*self._segments[seg]), **fetch_kwargs) for seg in segments]
+ elif isinstance(segments, str):
+ return self._handler.fetch(slice(*self._segments[segments]), **fetch_kwargs)
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py
new file mode 100644
index 000000000..905fcd623
--- /dev/null
+++ b/qlib/data/dataset/handler.py
@@ -0,0 +1,467 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+# coding=utf-8
+import abc
+import bisect
+import logging
+import warnings
+from typing import Union, Tuple, List, Iterator, Optional
+
+import pandas as pd
+import numpy as np
+
+from ...log import get_module_logger, TimeInspector
+from ...data import D
+from ...config import C
+from ...utils import parse_config, transform_end_date, init_instance_by_config
+from ...utils.serial import Serializable
+from .utils import get_level_index, fetch_df_by_index
+from pathlib import Path
+from .loader import DataLoader
+
+from . import processor as processor_module
+from . import loader as data_loader_module
+
+
+# TODO: A more general handler interface which does not relies on internal pd.DataFrame is needed.
+class DataHandler(Serializable):
+ """
+ The steps to using a handler
+ 1. initialized data handler (call by `init`).
+ 2. use the data.
+
+
+ The data handler try to maintain a handler with 2 level.
+ `datetime` & `instruments`.
+
+ Any order of the index level can be suported(The order will implied in the data).
+ The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
+
+ Example of the data:
+ The multi-index of the columns is optional.
+
+ .. code-block:: python
+
+ feature label
+ $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
+ datetime instrument
+ 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
+ SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
+ SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
+
+ """
+
+ def __init__(
+ self,
+ instruments=None,
+ start_time=None,
+ end_time=None,
+ data_loader: Tuple[dict, str, DataLoader] = None,
+ init_data=True,
+ fetch_orig=True,
+ ):
+ """
+ Parameters
+ ----------
+ instruments :
+ The stock list to retrive.
+ start_time :
+ start_time of the original data.
+ end_time :
+ end_time of the original data.
+ data_loader : Tuple[dict, str, DataLoader]
+ data loader to load the data.
+ init_data :
+ intialize the original data in the constructor.
+ fetch_orig : bool
+ Return the original data instead of copy if possible.
+ """
+ # Set logger
+ self.logger = get_module_logger("DataHandler")
+
+ # Setup data loader
+ assert data_loader is not None # to make start_time end_time could have None default value
+
+ self.data_loader = init_instance_by_config(
+ data_loader,
+ None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module,
+ accept_types=DataLoader,
+ )
+
+ self.instruments = instruments
+ self.start_time = start_time
+ self.end_time = end_time
+ self.fetch_orig = fetch_orig
+ if init_data:
+ with TimeInspector.logt("Init data"):
+ self.init()
+ super().__init__()
+
+ def init(self, enable_cache: bool = True):
+ """
+ initialize the data.
+ In case of running intialization for multiple time, it will do nothing for the second time.
+
+ It is responsible for maintaining following variable
+ 1) self._data
+
+ Parameters
+ ----------
+ enable_cache : bool
+ default value is false:
+
+ - if `enable_cache` == True:
+
+ the processed data will be saved on disk, and handler will load the cached data from the disk directly
+ when we call `init` next time
+ """
+ # Setup data.
+ # _data may be with multiple column index level. The outer level indicates the feature set name
+ with TimeInspector.logt("Loading data"):
+ self._data = self.data_loader.load(self.instruments, self.start_time, self.end_time)
+ # TODO: cache
+
+ CS_ALL = "__all" # return all columns with single-level index column
+ CS_RAW = "__raw" # return raw data with multi-level index column
+
+ def _fetch_df_by_col(self, df: pd.DataFrame, col_set: str) -> pd.DataFrame:
+ if not isinstance(df.columns, pd.MultiIndex) or col_set == self.CS_RAW:
+ return df
+ elif col_set == self.CS_ALL:
+ return df.droplevel(axis=1, level=0)
+ else:
+ return df.loc(axis=1)[col_set]
+
+ def fetch(
+ self,
+ selector: Union[pd.Timestamp, slice, str] = slice(None, None),
+ level: Union[str, int] = "datetime",
+ col_set: Union[str, List[str]] = CS_ALL,
+ squeeze: bool = False,
+ ) -> pd.DataFrame:
+ """
+ fetch data from underlying data source
+
+ Parameters
+ ----------
+ selector : Union[pd.Timestamp, slice, str]
+ describe how to select data by index
+ level : Union[str, int]
+ which index level to select the data
+ col_set : Union[str, List[str]]
+
+ - if isinstance(col_set, str):
+
+ select a set of meaningful columns.(e.g. features, columns)
+
+ - if isinstance(col_set, List[str]):
+
+ select several sets of meaningful columns, the returned data has multiple levels
+
+ squeeze : bool
+ whether squeeze columns and index
+
+ Returns
+ -------
+ pd.DataFrame.
+ """
+ # Fetch column first will be more friendly to SepDataFrame
+ df = self._fetch_df_by_col(self._data, col_set)
+ df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
+ if squeeze:
+ # squeeze columns
+ df = df.squeeze()
+ # squeeze index
+ if isinstance(selector, (str, pd.Timestamp)):
+ df = df.reset_index(level=level, drop=True)
+ return df
+
+ def get_cols(self, col_set=CS_ALL) -> list:
+ """
+ get the column names
+
+ Parameters
+ ----------
+ col_set : str
+ select a set of meaningful columns.(e.g. features, columns)
+
+ Returns
+ -------
+ list:
+ list of column names
+ """
+ df = self._data.head()
+ df = self._fetch_df_by_col(df, col_set)
+ return df.columns.to_list()
+
+ def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice:
+ """
+ get range selector by number of periods
+
+ Args:
+ cur_date (pd.Timestamp or str): current date
+ periods (int): number of periods
+ """
+ trading_dates = self._data.index.unique(level="datetime")
+ cur_loc = trading_dates.get_loc(cur_date)
+ pre_loc = cur_loc - periods + 1
+ if pre_loc < 0:
+ warnings.warn("`periods` is too large. the first date will be returned.")
+ pre_loc = 0
+ ref_date = trading_dates[pre_loc]
+ return slice(ref_date, cur_date)
+
+ def get_range_iterator(
+ self, periods: int, min_periods: Optional[int] = None, **kwargs
+ ) -> Iterator[Tuple[pd.Timestamp, pd.DataFrame]]:
+ """
+ get a iterator of sliced data with given periods
+
+ Args:
+ periods (int): number of periods.
+ min_periods (int): minimum periods for sliced dataframe.
+ kwargs (dict): will be passed to `self.fetch`.
+ """
+ trading_dates = self._data.index.unique(level="datetime")
+ if min_periods is None:
+ min_periods = periods
+ for cur_date in trading_dates[min_periods:]:
+ selector = self.get_range_selector(cur_date, periods)
+ yield cur_date, self.fetch(selector, **kwargs)
+
+
+class DataHandlerLP(DataHandler):
+ """
+ DataHandler with **(L)earnable (P)rocessor**
+ """
+
+ # data key
+ DK_R = "raw"
+ DK_I = "infer"
+ DK_L = "learn"
+
+ # process type
+ PTYPE_I = "independent"
+ # - self._infer will be processed by infer_processors
+ # - self._learn will be processed by learn_processors
+ PTYPE_A = "append"
+ # - self._infer will be processed by infer_processors
+ # - self._learn will be processed by infer_processors + learn_processors
+ # - (e.g. self._infer processed by learn_processors )
+
+ def __init__(
+ self,
+ instruments=None,
+ start_time=None,
+ end_time=None,
+ data_loader: Tuple[dict, str, DataLoader] = None,
+ infer_processors=[],
+ learn_processors=[],
+ process_type=PTYPE_A,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ infer_processors : list
+ - list of of processors to generate data for inference
+
+ - example of :
+
+ .. code-block::
+
+ 1) classname & kwargs:
+ {
+ "class": "MinMaxNorm",
+ "kwargs": {
+ "fit_start_time": "20080101",
+ "fit_end_time": "20121231"
+ }
+ }
+ 2) Only classname:
+ "DropnaFeature"
+ 3) object instance of Processor
+
+ learn_processors : list
+ similar to infer_processors, but for generating data for learning models
+
+ process_type: str
+ PTYPE_I = 'independent'
+
+ - self._infer will processed by infer_processors
+
+ - self._learn will be processed by learn_processors
+
+ PTYPE_A = 'append'
+
+ - self._infer will processed by infer_processors
+
+ - self._learn will be processed by infer_processors + learn_processors
+
+ - (e.g. self._infer processed by learn_processors )
+ """
+
+ # Setup preprocessor
+ self.infer_processors = [] # for lint
+ self.learn_processors = [] # for lint
+ for pname in "infer_processors", "learn_processors":
+ for proc in locals()[pname]:
+ getattr(self, pname).append(
+ init_instance_by_config(
+ proc,
+ None if (isinstance(proc, dict) and "module_path" in proc) else processor_module,
+ accept_types=processor_module.Processor,
+ )
+ )
+
+ self.process_type = process_type
+ super().__init__(instruments, start_time, end_time, data_loader, **kwargs)
+
+ def get_all_processors(self):
+ return self.infer_processors + self.learn_processors
+
+ def fit(self):
+ for proc in self.get_all_processors():
+ with TimeInspector.logt(f"{proc.__class__.__name__}"):
+ proc.fit(self._data)
+
+ def fit_process_data(self):
+ """
+ fit and process data
+
+ The input of the `fit` will be the output of the previous processor
+ """
+ self.process_data(with_fit=True)
+
+ def process_data(self, with_fit: bool = False):
+ """
+ process_data data. Fun `processor.fit` if necessary
+
+ Parameters
+ ----------
+ with_fit : bool
+ The input of the `fit` will be the output of the previous processor
+ """
+ # data for inference
+ _infer_df = self._data
+ if len(self.infer_processors) > 0: # avoid modifying the original data
+ _infer_df = _infer_df.copy()
+
+ for proc in self.infer_processors:
+ if not proc.is_for_infer():
+ raise TypeError("Only processors usable for inference can be used in `infer_processors` ")
+ with TimeInspector.logt(f"{proc.__class__.__name__}"):
+ if with_fit:
+ proc.fit(_infer_df)
+ _infer_df = proc(_infer_df)
+ self._infer = _infer_df
+
+ # data for learning
+ if self.process_type == DataHandlerLP.PTYPE_I:
+ _learn_df = self._data
+ elif self.process_type == DataHandlerLP.PTYPE_A:
+ # based on `infer_df` and append the processor
+ _learn_df = _infer_df
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ if len(self.learn_processors) > 0: # avoid modifying the original data
+ _learn_df = _learn_df.copy()
+ for proc in self.learn_processors:
+ with TimeInspector.logt(f"{proc.__class__.__name__}"):
+ if with_fit:
+ proc.fit(_learn_df)
+ _learn_df = proc(_learn_df)
+ self._learn = _learn_df
+
+ # init type
+ IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
+ IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
+ IT_LS = "load_state" # The state of the object has been load by pickle
+
+ def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
+ """
+ Initialize the data of Qlib
+
+ Parameters
+ ----------
+ init_type : str
+ The type `IT_*` listed above.
+ enable_cache : bool
+ default value is false:
+
+ - if `enable_cache` == True:
+
+ the processed data will be saved on disk, and handler will load the cached data from the disk directly
+ when we call `init` next time
+ """
+ # init raw data
+ super().init(enable_cache=enable_cache)
+
+ with TimeInspector.logt("fit & process data"):
+ if init_type == DataHandlerLP.IT_FIT_IND:
+ self.fit()
+ self.process_data()
+ elif init_type == DataHandlerLP.IT_LS:
+ self.process_data()
+ elif init_type == DataHandlerLP.IT_FIT_SEQ:
+ self.fit_process_data()
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+ # TODO: Be able to cache handler data. Save the memory for data processing
+
+ def _get_df_by_key(self, data_key: str = DK_I) -> pd.DataFrame:
+ df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
+ return df
+
+ def fetch(
+ self,
+ selector: Union[pd.Timestamp, slice, str] = slice(None, None),
+ level: Union[str, int] = "datetime",
+ col_set=DataHandler.CS_ALL,
+ data_key: str = DK_I,
+ ) -> pd.DataFrame:
+ """
+ fetch data from underlying data source
+
+ Parameters
+ ----------
+ selector : Union[pd.Timestamp, slice, str]
+ describe how to select data by index.
+ level : Union[str, int]
+ which index level to select the data.
+ col_set : str
+ select a set of meaningful columns.(e.g. features, columns).
+ data_key : str
+ the data to fetch: DK_*.
+
+ Returns
+ -------
+ pd.DataFrame:
+ """
+ df = self._get_df_by_key(data_key)
+ # Fetch column first will be more friendly to SepDataFrame
+ df = self._fetch_df_by_col(df, col_set)
+ return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
+
+ def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
+ """
+ get the column names
+
+ Parameters
+ ----------
+ col_set : str
+ select a set of meaningful columns.(e.g. features, columns).
+ data_key : str
+ the data to fetch: DK_*.
+
+ Returns
+ -------
+ list:
+ list of column names
+ """
+ df = self._get_df_by_key(data_key).head()
+ df = self._fetch_df_by_col(df, col_set)
+ return df.columns.to_list()
diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py
new file mode 100644
index 000000000..a51ea119a
--- /dev/null
+++ b/qlib/data/dataset/loader.py
@@ -0,0 +1,198 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+import abc
+import warnings
+import numpy as np
+import pandas as pd
+
+from typing import Tuple, Union
+
+from qlib.data import D
+from qlib.utils import load_dataset
+
+
+class DataLoader(abc.ABC):
+ """
+ DataLoader is designed for loading raw data from original data source.
+ """
+
+ @abc.abstractmethod
+ def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
+ """
+ load the data as pd.DataFrame.
+
+ Example of the data (The multi-index of the columns is optional.):
+
+ .. code-block:: python
+
+ feature label
+ $close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
+ datetime instrument
+ 2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
+ SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
+ SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
+
+
+ Parameters
+ ----------
+ instruments : str or dict
+ it can either be the market name or the config file of instruments generated by InstrumentProvider.
+ start_time : str
+ start of the time range.
+ end_time : str
+ end of the time range.
+
+ Returns
+ -------
+ pd.DataFrame:
+ data load from the under layer source
+ """
+ pass
+
+
+class DLWParser(DataLoader):
+ """
+ (D)ata(L)oader (W)ith (P)arser for features and names
+
+ Extracting this class so that QlibDataLoader and other dataloaders(such as QdbDataLoader) can share the fields.
+ """
+
+ def __init__(self, config: Tuple[list, tuple, dict]):
+ """
+ Parameters
+ ----------
+ config : Tuple[list, tuple, dict]
+ Config will be used to describe the fields and column names
+
+ .. code-block::
+
+ := {
+ "group_name1":
+ "group_name2":
+ }
+ or
+ :=
+
+ := ["expr", ...] | (["expr", ...], ["col_name", ...])
+ """
+ self.is_group = isinstance(config, dict)
+
+ if self.is_group:
+ self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}
+ else:
+ self.fields = self._parse_fields_info(config)
+
+ def _parse_fields_info(self, fields_info: Tuple[list, tuple]) -> Tuple[list, list]:
+ if isinstance(fields_info, list):
+ exprs = names = fields_info
+ elif isinstance(fields_info, tuple):
+ exprs, names = fields_info
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ return exprs, names
+
+ @abc.abstractmethod
+ def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
+ """
+ load the dataframe for specific group
+
+ Parameters
+ ----------
+ instruments :
+ the instruments.
+ exprs : list
+ the expressions to describe the content of the data.
+ names : list
+ the name of the data.
+
+ Returns
+ -------
+ pd.DataFrame:
+ the queried dataframe.
+ """
+ pass
+
+ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
+ if self.is_group:
+ df = pd.concat(
+ {
+ grp: self.load_group_df(instruments, exprs, names, start_time, end_time)
+ for grp, (exprs, names) in self.fields.items()
+ },
+ axis=1,
+ )
+ else:
+ exprs, names = self.fields
+ df = self.load_group_df(instruments, exprs, names, start_time, end_time)
+ return df
+
+
+class QlibDataLoader(DLWParser):
+ """Same as QlibDataLoader. The fields can be define by config"""
+
+ def __init__(self, config: Tuple[list, tuple, dict], filter_pipe=None):
+ """
+ Parameters
+ ----------
+ config : Tuple[list, tuple, dict]
+ Please refer to the doc of DLWParser
+ filter_pipe :
+ Filter pipe for the instruments
+ """
+ self.filter_pipe = filter_pipe
+ super().__init__(config)
+
+ def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
+ if instruments is None:
+ warnings.warn("`instruments` is not set, will load all stocks")
+ instruments = "all"
+ if isinstance(instruments, str):
+ instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
+ elif self.filter_pipe is not None:
+ warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list")
+
+ df = D.features(instruments, exprs, start_time, end_time)
+ df.columns = names
+ df = df.swaplevel().sort_index() # NOTE: always return
+ return df
+
+
+class StaticDataLoader(DataLoader):
+ """
+ DataLoader that supports loading data from file or as provided.
+ """
+
+ def __init__(self, config: dict, join="outer"):
+ """
+ Parameters
+ ----------
+ config : dict
+ {fields_group: }
+ join : str
+ How to align different dataframes
+ """
+ self.config = config
+ self.join = join
+ self._data = None
+
+ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
+ self._maybe_load_raw_data()
+ if instruments is None:
+ df = self._data
+ else:
+ df = self._data.loc(axis=0)[:, instruments]
+ if start_time is None and end_time is None:
+ return df # NOTE: avoid copy by loc
+ return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
+
+ def _maybe_load_raw_data(self):
+ if self._data is not None:
+ return
+ self._data = pd.concat(
+ {fields_group: load_dataset(path_or_obj) for fields_group, path_or_obj in self.config.items()},
+ axis=1,
+ join=self.join,
+ )
+ self._data.sort_index(inplace=True)
diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py
new file mode 100755
index 000000000..76cf85c4a
--- /dev/null
+++ b/qlib/data/dataset/processor.py
@@ -0,0 +1,285 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import abc
+import numpy as np
+import pandas as pd
+import copy
+
+from ...log import TimeInspector
+from .utils import fetch_df_by_index
+from ...utils.serial import Serializable
+from ...utils.paral import datetime_groupby_apply
+
+EPS = 1e-12
+
+
+def get_group_columns(df: pd.DataFrame, group: str):
+ """
+ get a group of columns from multi-index columns DataFrame
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ with multi of columns.
+ group : str
+ the name of the feature group, i.e. the first level value of the group index.
+ """
+ if group is None:
+ return df.columns
+ else:
+ return df.columns[df.columns.get_loc(group)]
+
+
+class Processor(Serializable):
+ def fit(self, df: pd.DataFrame = None):
+ """
+ learn data processing parameters
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ When we fit and process data with processor one by one. The fit function reiles on the output of previous
+ processor, i.e. `df`.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def __call__(self, df: pd.DataFrame):
+ """
+ process the data
+
+ NOTE: **The processor could change the content of `df` inplace !!!!! **
+ User should keep a copy of data outside
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ The raw_df of handler or result from previous processor.
+ """
+ pass
+
+ def is_for_infer(self) -> bool:
+ """
+ Is this processor usable for inference
+ Some processors are not usable for inference.
+
+ Returns
+ -------
+ bool:
+ if it is usable for infenrece.
+ """
+ return True
+
+
+class DropnaProcessor(Processor):
+ def __init__(self, fields_group=None):
+ self.fields_group = fields_group
+
+ def __call__(self, df):
+ return df.dropna(subset=get_group_columns(df, self.fields_group))
+
+
+class DropnaLabel(DropnaProcessor):
+ def __init__(self, fields_group="label"):
+ super().__init__(fields_group=fields_group)
+
+ def is_for_infer(self) -> bool:
+ """The samples are dropped according to label. So it is not usable for inference"""
+ return False
+
+
+class DropCol(Processor):
+ def __init__(self, col_list=[]):
+ self.col_list = col_list
+
+ def __call__(self, df):
+ if isinstance(df.columns, pd.MultiIndex):
+ mask = df.columns.get_level_values(-1).isin(self.col_list)
+ else:
+ mask = df.columns.isin(self.col_list)
+ return df.loc[:, ~mask]
+
+
+class TanhProcess(Processor):
+ """ Use tanh to process noise data"""
+
+ def __call__(self, df):
+ def tanh_denoise(data):
+ mask = data.columns.get_level_values(1).str.contains("LABEL")
+ col = df.columns[~mask]
+ data[col] = data[col] - 1
+ data[col] = np.tanh(data[col])
+
+ return data
+
+ return tanh_denoise(df)
+
+
+class ProcessInf(Processor):
+ """Process infinity """
+
+ def __call__(self, df):
+ def replace_inf(data):
+ def process_inf(df):
+ for col in df.columns:
+ # FIXME: Such behavior is very weird
+ df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean())
+ return df
+
+ data = datetime_groupby_apply(data, process_inf)
+ data.sort_index(inplace=True)
+ return data
+
+ return replace_inf(df)
+
+
+class Fillna(Processor):
+ """Process NaN"""
+
+ def __init__(self, fields_group=None, fill_value=0):
+ self.fields_group = fields_group
+ self.fill_value = fill_value
+
+ def __call__(self, df):
+ if self.fields_group is None:
+ df.fillna(self.fill_value, inplace=True)
+ else:
+ cols = get_group_columns(df, self.fields_group)
+ df.fillna({col: self.fill_value for col in cols}, inplace=True)
+ return df
+
+
+class MinMaxNorm(Processor):
+ def __init__(self, fit_start_time, fit_end_time, fields_group=None):
+ self.fit_start_time = fit_start_time
+ self.fit_end_time = fit_end_time
+ self.fields_group = fields_group
+
+ def fit(self, df):
+ df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
+ cols = get_group_columns(df, self.fields_group)
+ self.min_val = np.nanmin(df[cols].values, axis=0)
+ self.max_val = np.nanmax(df[cols].values, axis=0)
+ self.ignore = self.min_val == self.max_val
+ self.cols = cols
+
+ def __call__(self, df):
+ def normalize(x, min_val=self.min_val, max_val=self.max_val, ignore=self.ignore):
+ if (~ignore).all():
+ return (x - min_val) / (max_val - min_val)
+ for i in range(ignore.size):
+ if not ignore[i]:
+ x[i] = (x[i] - min_val) / (max_val - min_val)
+ return x
+
+ df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
+ return df
+
+
+class ZScoreNorm(Processor):
+ """ZScore Normalization"""
+
+ def __init__(self, fit_start_time, fit_end_time, fields_group=None):
+ self.fit_start_time = fit_start_time
+ self.fit_end_time = fit_end_time
+ self.fields_group = fields_group
+
+ def fit(self, df):
+ df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
+ cols = get_group_columns(df, self.fields_group)
+ self.mean_train = np.nanmean(df[cols].values, axis=0)
+ self.std_train = np.nanstd(df[cols].values, axis=0)
+ self.ignore = self.std_train == 0
+ self.cols = cols
+
+ def __call__(self, df):
+ def normalize(x, mean_train=self.mean_train, std_train=self.std_train, ignore=self.ignore):
+ if (~ignore).all():
+ return (x - mean_train) / std_train
+ for i in range(ignore.size):
+ if not ignore[i]:
+ x[i] = (x[i] - mean_train) / std_train
+ return x
+
+ df.loc(axis=1)[self.cols] = normalize(df[self.cols].values)
+ return df
+
+
+class RobustZScoreNorm(Processor):
+ """Robust ZScore Normalization
+
+ Use robust statistics for Z-Score normalization:
+ mean(x) = median(x)
+ std(x) = MAD(x) * 1.4826
+
+ Reference:
+ https://en.wikipedia.org/wiki/Median_absolute_deviation.
+ """
+
+ def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True):
+ self.fit_start_time = fit_start_time
+ self.fit_end_time = fit_end_time
+ self.fields_group = fields_group
+ self.clip_outlier = clip_outlier
+
+ def fit(self, df):
+ df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
+ self.cols = get_group_columns(df, self.fields_group)
+ X = df[self.cols].values
+ self.mean_train = np.nanmedian(X, axis=0)
+ self.std_train = np.nanmedian(np.abs(X - self.mean_train), axis=0)
+ self.std_train += EPS
+ self.std_train *= 1.4826
+
+ def __call__(self, df):
+ X = df[self.cols]
+ X -= self.mean_train
+ X /= self.std_train
+ df[self.cols] = X
+ if self.clip_outlier:
+ df.clip(-3, 3, inplace=True)
+ return df
+
+
+class CSZScoreNorm(Processor):
+ """Cross Sectional ZScore Normalization"""
+
+ def __init__(self, fields_group=None):
+ self.fields_group = fields_group
+
+ def __call__(self, df):
+ # try not modify original dataframe
+ cols = get_group_columns(df, self.fields_group)
+ df[cols] = df[cols].groupby("datetime").apply(lambda x: (x - x.mean()).div(x.std()))
+
+ return df
+
+
+class CSRankNorm(Processor):
+ """Cross Sectional Rank Normalization"""
+
+ def __init__(self, fields_group=None):
+ self.fields_group = fields_group
+
+ def __call__(self, df):
+ # try not modify original dataframe
+ cols = get_group_columns(df, self.fields_group)
+ t = df[cols].groupby("datetime").rank(pct=True)
+ t -= 0.5
+ t *= 3.46 # NOTE: towards unit std
+ df[cols] = t
+ return df
+
+
+class CSZFillna(Processor):
+ """Cross Sectional Fill Nan"""
+
+ def __init__(self, fields_group=None):
+ self.fields_group = fields_group
+
+ def __call__(self, df):
+ cols = get_group_columns(df, self.fields_group)
+ df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
+ return df
diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py
new file mode 100644
index 000000000..feda19044
--- /dev/null
+++ b/qlib/data/dataset/utils.py
@@ -0,0 +1,72 @@
+from typing import Union
+import pandas as pd
+
+
+def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
+ """
+
+ get the level index of `df` given `level`
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ data
+ level : Union[str, int]
+ index level
+
+ Returns
+ -------
+ int:
+ The level index in the multiple index
+ """
+ if isinstance(level, str):
+ try:
+ return df.index.names.index(level)
+ except (AttributeError, ValueError):
+ # NOTE: If level index is not given in the data, the default level index will be ('datetime', 'instrument')
+ return ("datetime", "instrument").index(level)
+ elif isinstance(level, int):
+ return level
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+
+
+def fetch_df_by_index(
+ df: pd.DataFrame,
+ selector: Union[pd.Timestamp, slice, str, list],
+ level: Union[str, int],
+ fetch_orig=True,
+) -> pd.DataFrame:
+ """
+ fetch data from `data` with `selector` and `level`
+
+ Parameters
+ ----------
+ selector : Union[pd.Timestamp, slice, str, list]
+ selector
+ level : Union[int, str]
+ the level to use the selector
+
+ Returns
+ -------
+ Data of the given index.
+ """
+ # level = None -> use selector directly
+ if level == None:
+ return df.loc(axis=0)[selector]
+ # Try to get the right index
+ idx_slc = (selector, slice(None, None))
+ if get_level_index(df, level) == 1:
+ idx_slc = idx_slc[1], idx_slc[0]
+ if fetch_orig:
+ for slc in idx_slc:
+ if slc != slice(None, None):
+ return df.loc[
+ pd.IndexSlice[idx_slc],
+ ]
+ else:
+ return df
+ else:
+ return df.loc[
+ pd.IndexSlice[idx_slc],
+ ]
diff --git a/qlib/data/filter.py b/qlib/data/filter.py
index 3a36b1678..70f9d3278 100644
--- a/qlib/data/filter.py
+++ b/qlib/data/filter.py
@@ -7,14 +7,12 @@ from abc import abstractmethod
import re
import pandas as pd
import numpy as np
-import six
import abc
from .data import Cal, DatasetD
-@six.add_metaclass(abc.ABCMeta)
-class BaseDFilter(object):
+class BaseDFilter(abc.ABC):
"""Dynamic Instruments Filter Abstract class
Users can override this class to construct their own filter
@@ -34,7 +32,7 @@ class BaseDFilter(object):
Parameters
----------
config : dict
- dict of config parameters
+ dict of config parameters.
"""
raise NotImplementedError("Subclass of BaseDFilter must reimplement `from_config` method")
@@ -45,12 +43,11 @@ class BaseDFilter(object):
Returns
----------
dict
- return the dict of config parameters
+ return the dict of config parameters.
"""
raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method")
-@six.add_metaclass(abc.ABCMeta)
class SeriesDFilter(BaseDFilter):
"""Dynamic Instruments Filter Abstract class to filter a series of certain features
@@ -72,9 +69,9 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
fstart_time: str
- the time for the filter rule to start filter the instruments
+ the time for the filter rule to start filter the instruments.
fend_time: str
- the time for the filter rule to stop filter the instruments
+ the time for the filter rule to stop filter the instruments.
"""
super(SeriesDFilter, self).__init__()
self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None
@@ -86,12 +83,12 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
instruments: dict
- the dict of instruments in the form {instrument_name => list of timestamp tuple}
+ the dict of instruments in the form {instrument_name => list of timestamp tuple}.
Returns
----------
pd.Timestamp, pd.Timestamp
- the lower time bound and upper time bound of all the instruments
+ the lower time bound and upper time bound of all the instruments.
"""
trange = Cal.calendar(freq=self.filter_freq)
ubound, lbound = trange[0], trange[-1]
@@ -108,14 +105,14 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
time_range : D.calendar
- the time range of the instruments
+ the time range of the instruments.
target_timestamp : list
- the list of tuple (timestamp, timestamp)
+ the list of tuple (timestamp, timestamp).
Returns
----------
pd.Series
- the series of bool value for an instrument
+ the series of bool value for an instrument.
"""
# Construct a whole dict of {date => bool}
timestamp_series = {timestamp: False for timestamp in time_range}
@@ -127,19 +124,19 @@ class SeriesDFilter(BaseDFilter):
return timestamp_series
def _filterSeries(self, timestamp_series, filter_series):
- """Filter the timestamp series with filter series by using element-wise AND operation of the two series
+ """Filter the timestamp series with filter series by using element-wise AND operation of the two series.
Parameters
----------
timestamp_series : pd.Series
- the series of bool value indicating existing time
+ the series of bool value indicating existing time.
filter_series : pd.Series
- the series of bool value indicating filter feature
+ the series of bool value indicating filter feature.
Returns
----------
pd.Series
- the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp
+ the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp.
"""
fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1]
filter_series = filter_series.astype("bool") # Make sure the filter_series is boolean
@@ -147,17 +144,17 @@ class SeriesDFilter(BaseDFilter):
return timestamp_series
def _toTimestamp(self, timestamp_series):
- """Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE
+ """Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE.
Parameters
----------
timestamp_series: pd.Series
- the series of bool value after being filtered
+ the series of bool value after being filtered.
Returns
----------
list
- the list of tuple (timestamp, timestamp)
+ the list of tuple (timestamp, timestamp).
"""
# sort the timestamp_series according to the timestamps
timestamp_series.sort_index()
@@ -197,18 +194,18 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
instruments : dict
- the dict of instruments to be filtered
+ the dict of instruments to be filtered.
fstart : pd.Timestamp
- start time of filter
+ start time of filter.
fend : pd.Timestamp
- end time of filter
+ end time of filter.
- .. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time
+ .. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time.
Returns
----------
pd.Dataframe
- a series of {pd.Timestamp => bool}
+ a series of {pd.Timestamp => bool}.
"""
raise NotImplementedError("Subclass of SeriesDFilter must reimplement `getFilterSeries` method")
@@ -218,16 +215,16 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
instruments: dict
- input instruments to be filtered
+ input instruments to be filtered.
start_time: str
- start of the time range
+ start of the time range.
end_time: str
- end of the time range
+ end of the time range.
Returns
----------
dict
- filtered instruments, same structure as input instruments
+ filtered instruments, same structure as input instruments.
"""
lbound, ubound = self._getTimeBound(instruments)
start_time = pd.Timestamp(start_time or lbound)
@@ -275,7 +272,7 @@ class NameDFilter(SeriesDFilter):
params:
------
name_rule_re: str
- regular expression for the name rule
+ regular expression for the name rule.
"""
super(NameDFilter, self).__init__(fstart_time, fend_time)
self.name_rule_re = name_rule_re
@@ -328,13 +325,13 @@ class ExpressionDFilter(SeriesDFilter):
params:
------
fstart_time: str
- filter the feature starting from this time
+ filter the feature starting from this time.
fend_time: str
- filter the feature ending by this time
+ filter the feature ending by this time.
rule_expression: str
- an input expression for the rule
+ an input expression for the rule.
keep: bool
- whether to keep the instruments of which features don't exist in the filter time span
+ whether to keep the instruments of which features don't exist in the filter time span.
"""
super(ExpressionDFilter, self).__init__(fstart_time, fend_time)
self.rule_expression = rule_expression
diff --git a/qlib/data/ops.py b/qlib/data/ops.py
index cd7a541d3..e17c0e4e6 100644
--- a/qlib/data/ops.py
+++ b/qlib/data/ops.py
@@ -9,6 +9,8 @@ import sys
import numpy as np
import pandas as pd
+from scipy.stats import percentileofscore
+
from .base import Expression, ExpressionOps
from ..log import get_module_logger
@@ -687,6 +689,8 @@ class Rolling(ExpressionOps):
# isnull = series.isnull() # NOTE: isnull = NaN, inf is not null
if self.N == 0:
series = getattr(series.expanding(min_periods=1), self.func)()
+ elif 0 < self.N < 1:
+ series = series.ewm(alpha=self.N, min_periods=1).mean()
else:
series = getattr(series.rolling(self.N, min_periods=1), self.func)()
# series.iloc[:self.N-1] = np.nan
@@ -696,6 +700,8 @@ class Rolling(ExpressionOps):
def get_longest_back_rolling(self):
if self.N == 0:
return np.inf
+ if 0 < self.N < 1:
+ return int(np.log(1e-6) / np.log(1 - self.N)) # (1 - N)**window == 1e-6
return self.feature.get_longest_back_rolling() + self.N - 1
def get_extended_window_size(self):
@@ -704,6 +710,11 @@ class Rolling(ExpressionOps):
# remove such support for N == 0?
get_module_logger(self.__class__.__name__).warning("The Rolling(ATTR, 0) will not be accurately calculated")
return self.feature.get_extended_window_size()
+ elif 0 < self.N < 1:
+ lft_etd, rght_etd = self.feature.get_extended_window_size()
+ size = int(np.log(1e-6) / np.log(1 - self.N))
+ lft_etd = max(lft_etd + size - 1, lft_etd)
+ return lft_etd, rght_etd
else:
lft_etd, rght_etd = self.feature.get_extended_window_size()
lft_etd = max(lft_etd + self.N - 1, lft_etd)
@@ -1091,7 +1102,7 @@ class Rank(Rolling):
x1 = x[~np.isnan(x)]
if x1.shape[0] == 0:
return np.nan
- return (x1.argsort()[-1] + 1) / len(x1)
+ return percentileofscore(x1, x1[-1]) / len(x1)
if self.N == 0:
series = series.expanding(min_periods=1).apply(rank, raw=True)
@@ -1277,7 +1288,7 @@ class EMA(Rolling):
----------
feature : Expression
feature instance
- N : int
+ N : int, float
rolling window size
Returns
@@ -1300,6 +1311,8 @@ class EMA(Rolling):
if self.N == 0:
series = series.expanding(min_periods=1).apply(exp_weighted_mean, raw=True)
+ elif 0 < self.N < 1:
+ series = series.ewm(alpha=self.N, min_periods=1).mean()
else:
series = series.ewm(span=self.N, min_periods=1).mean()
return series
diff --git a/qlib/log.py b/qlib/log.py
index bc87fc579..422a4c00b 100644
--- a/qlib/log.py
+++ b/qlib/log.py
@@ -2,12 +2,13 @@
# Licensed under the MIT License.
+import logging
+import logging.handlers
import os
import re
-import logging
-from time import time
-import logging.handlers
from logging import config as logging_config
+from time import time
+from contextlib import contextmanager
from .config import C
@@ -77,7 +78,29 @@ class TimeInspector(object):
Info that will be log into stdout.
"""
cost_time = time() - cls.time_marks.pop()
- cls.timer_logger.info("Time cost: {0:.5f} | {1}".format(cost_time, info))
+ cls.timer_logger.info("Time cost: {0:.3f}s | {1}".format(cost_time, info))
+
+ @classmethod
+ @contextmanager
+ def logt(cls, name="", show_start=False):
+ """logt.
+ Log the time of the inside code
+
+ Parameters
+ ----------
+ name :
+ name
+ show_start :
+ show_start
+ """
+ if show_start:
+ cls.timer_logger.info(f"{name} Begin")
+ cls.set_time_mark()
+ try:
+ yield None
+ finally:
+ pass
+ cls.log_cost_time(info=f"{name} Done")
def set_log_with_config(log_config: dict):
diff --git a/qlib/model/__init__.py b/qlib/model/__init__.py
new file mode 100644
index 000000000..c639b57f5
--- /dev/null
+++ b/qlib/model/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import warnings
+
+from .base import Model
diff --git a/qlib/model/base.py b/qlib/model/base.py
new file mode 100644
index 000000000..c9bef1152
--- /dev/null
+++ b/qlib/model/base.py
@@ -0,0 +1,81 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import abc
+from ..utils.serial import Serializable
+from ..data.dataset import Dataset
+
+
+class BaseModel(Serializable, metaclass=abc.ABCMeta):
+ """Modeling things"""
+
+ @abc.abstractmethod
+ def predict(self, *args, **kwargs) -> object:
+ """ Make predictions after modeling things """
+ pass
+
+ def __call__(self, *args, **kwargs) -> object:
+ """ leverage Python syntactic sugar to make the models' behaviors like functions """
+ return self.predict(*args, **kwargs)
+
+
+class Model(BaseModel):
+ """Learnable Models"""
+
+ def fit(self, dataset: Dataset):
+ """
+ Learn model from the base model
+
+ .. note::
+
+ The the attribute names of learned model should `not` start with '_'. So that the model could be
+ dumped to disk.
+
+ Parameters
+ ----------
+ dataset : Dataset
+ dataset will generate the processed data from model training.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def predict(self, dataset: Dataset) -> object:
+ """give prediction given Dataset
+
+ Parameters
+ ----------
+ dataset : Dataset
+ dataset will generate the processed dataset from model training.
+ """
+ raise NotImplementedError()
+
+
+class ModelFT(Model):
+ """Model (F)ine(t)unable"""
+
+ @abc.abstractmethod
+ def finetune(self, dataset: Dataset):
+ """finetune model based given dataset
+
+ A typical use case of finetuning model with qlib.workflow.R
+
+ .. code-block:: python
+
+ # start exp to train init model
+ with R.start(experiment_name="init models"):
+ model.fit(dataset)
+ R.save_objects(init_model=model)
+ rid = R.get_recorder().id
+
+ # Finetune model based on previous trained model
+ with R.start(experiment_name="finetune model"):
+ recorder = R.get_recorder(rid, experiment_name="init models")
+ model = recorder.load_object("init_model")
+ model.finetune(dataset, num_boost_round=10)
+
+
+ Parameters
+ ----------
+ dataset : Dataset
+ dataset will generate the processed dataset from model training.
+ """
+ raise NotImplementedError()
diff --git a/qlib/model/riskmodel.py b/qlib/model/riskmodel.py
new file mode 100644
index 000000000..07a1e0c9f
--- /dev/null
+++ b/qlib/model/riskmodel.py
@@ -0,0 +1,467 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import warnings
+import numpy as np
+import pandas as pd
+
+from typing import Union
+
+from qlib.model.base import BaseModel
+
+
+class RiskModel(BaseModel):
+ """Risk Model
+
+ A risk model is used to estimate the covariance matrix of stock returns.
+ """
+
+ MASK_NAN = "mask"
+ FILL_NAN = "fill"
+ IGNORE_NAN = "ignore"
+
+ def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True):
+ """
+ Args:
+ nan_option (str): nan handling option (`ignore`/`mask`/`fill`).
+ assume_centered (bool): whether the data is assumed to be centered.
+ scale_return (bool): whether scale returns as percentage.
+ """
+ # nan
+ assert nan_option in [
+ self.MASK_NAN,
+ self.FILL_NAN,
+ self.IGNORE_NAN,
+ ], f"`nan_option={nan_option}` is not supported"
+ self.nan_option = nan_option
+
+ self.assume_centered = assume_centered
+ self.scale_return = scale_return
+
+ def predict(
+ self, X: Union[pd.Series, pd.DataFrame, np.ndarray], return_corr: bool = False, is_price: bool = True
+ ) -> Union[pd.DataFrame, np.ndarray]:
+ """
+ Args:
+ X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance,
+ with variables as columns and observations as rows.
+ return_corr (bool): whether return the correlation matrix.
+ is_price (bool): whether `X` contains price (if not assume stock returns).
+
+ Returns:
+ pd.DataFrame or np.ndarray: estimated covariance (or correlation).
+ """
+ # transform input into 2D array
+ if not isinstance(X, (pd.Series, pd.DataFrame)):
+ columns = None
+ else:
+ if isinstance(X.index, pd.MultiIndex):
+ if isinstance(X, pd.DataFrame):
+ X = X.iloc[:, 0].unstack(level="instrument") # always use the first column
+ else:
+ X = X.unstack(level="instrument")
+ else:
+ # X is 2D DataFrame
+ pass
+ columns = X.columns # will be used to restore dataframe
+ X = X.values
+
+ # calculate pct_change
+ if is_price:
+ X = X[1:] / X[:-1] - 1 # NOTE: resulting `n - 1` rows
+
+ # scale return
+ if self.scale_return:
+ X *= 100
+
+ # handle nan and centered
+ X = self._preprocess(X)
+
+ # estimate covariance
+ S = self._predict(X)
+
+ # return correlation if needed
+ if return_corr:
+ vola = np.sqrt(np.diag(S))
+ corr = S / np.outer(vola, vola)
+ if columns is None:
+ return corr
+ return pd.DataFrame(corr, index=columns, columns=columns)
+
+ # return covariance
+ if columns is None:
+ return S
+ return pd.DataFrame(S, index=columns, columns=columns)
+
+ def _predict(self, X: np.ndarray) -> np.ndarray:
+ """covariance estimation implementation
+
+ This method should be overridden by child classes.
+
+ By default, this method implements the empirical covariance estimation.
+
+ Args:
+ X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).
+
+ Returns:
+ np.ndarray: covariance matrix.
+ """
+ xTx = np.asarray(X.T.dot(X))
+ N = len(X)
+ if isinstance(X, np.ma.MaskedArray):
+ M = 1 - X.mask
+ N = M.T.dot(M) # each pair has distinct number of samples
+ return xTx / N
+
+ def _preprocess(self, X: np.ndarray) -> Union[np.ndarray, np.ma.MaskedArray]:
+ """handle nan and centerize data
+
+ Note:
+ if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`.
+ """
+ # handle nan
+ if self.nan_option == self.FILL_NAN:
+ X = np.nan_to_num(X)
+ elif self.nan_option == self.MASK_NAN:
+ X = np.ma.masked_invalid(X)
+ # centerize
+ if not self.assume_centered:
+ X = X - np.nanmean(X, axis=0)
+ return X
+
+
+class ShrinkCovEstimator(RiskModel):
+ """Shrinkage Covariance Estimator
+
+ This estimator will shrink the sample covariance matrix towards
+ an identify matrix:
+ S_hat = (1 - alpha) * S + alpha * F
+ where `alpha` is the shrink parameter and `F` is the shrinking target.
+
+ The following shrinking parameters (`alpha`) are supported:
+ - `lw` [1][2][3]: use Ledoit-Wolf shrinking parameter.
+ - `oas` [4]: use Oracle Approximating Shrinkage shrinking parameter.
+ - float: directly specify the shrink parameter, should be between [0, 1].
+
+ The following shrinking targets (`F`) are supported:
+ - `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation.
+ - `const_corr` [2][6]: assume stocks have different variance but equal correlation.
+ - `single_factor` [3][7]: assume single factor model as the shrinking target.
+ - np.ndarray: provide the shrinking targets directly.
+
+ Note:
+ - The optimal shrinking parameter depends on the selection of the shrinking target.
+ Currently, `oas` is not supported for `const_corr` and `single_factor`.
+ - Remember to set `nan_option` to `fill` or `mask` if your data has missing values.
+
+ References:
+ [1] Ledoit, O., & Wolf, M. (2004). A well-conditioned estimator for large-dimensional covariance matrices.
+ Journal of Multivariate Analysis, 88(2), 365–411. https://doi.org/10.1016/S0047-259X(03)00096-4
+ [2] Ledoit, O., & Wolf, M. (2004). Honey, I shrunk the sample covariance matrix.
+ Journal of Portfolio Management, 30(4), 1–22. https://doi.org/10.3905/jpm.2004.110
+ [3] Ledoit, O., & Wolf, M. (2003). Improved estimation of the covariance matrix of stock returns
+ with an application to portfolio selection.
+ Journal of Empirical Finance, 10(5), 603–621. https://doi.org/10.1016/S0927-5398(03)00007-0
+ [4] Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. (2010). Shrinkage algorithms for MMSE covariance estimation.
+ IEEE Transactions on Signal Processing, 58(10), 5016–5029. https://doi.org/10.1109/TSP.2010.2053029
+ [5] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-0000-00007f64e5b9/cov1para.m.zip
+ [6] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-ffff-ffffde5e2d4e/covCor.m.zip
+ [7] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-0000-0000648dfc98/covMarket.m.zip
+ """
+
+ SHR_LW = "lw"
+ SHR_OAS = "oas"
+
+ TGT_CONST_VAR = "const_var"
+ TGT_CONST_CORR = "const_corr"
+ TGT_SINGLE_FACTOR = "single_factor"
+
+ def __init__(self, alpha: Union[str, float] = 0.0, target: Union[str, np.ndarray] = "const_var", **kwargs):
+ """
+ Args:
+ alpha (str or float): shrinking parameter or estimator (`lw`/`oas`)
+ target (str or np.ndarray): shrinking target (`const_var`/`const_corr`/`single_factor`)
+ kwargs: see `RiskModel` for more information
+ """
+ super().__init__(**kwargs)
+
+ # alpha
+ if isinstance(alpha, str):
+ assert alpha in [self.SHR_LW, self.SHR_OAS], f"shrinking method `{alpha}` is not supported"
+ elif isinstance(alpha, (float, np.floating)):
+ assert 0 <= alpha <= 1, "alpha should be between [0, 1]"
+ else:
+ raise TypeError("invalid argument type for `alpha`")
+ self.alpha = alpha
+
+ # target
+ if isinstance(target, str):
+ assert target in [
+ self.TGT_CONST_VAR,
+ self.TGT_CONST_CORR,
+ self.TGT_SINGLE_FACTOR,
+ ], f"shrinking target `{target} is not supported"
+ elif isinstance(target, np.ndarray):
+ pass
+ else:
+ raise TypeError("invalid argument type for `target`")
+ if alpha == self.SHR_OAS and target != self.TGT_CONST_VAR:
+ raise NotImplementedError("currently `oas` can only support `const_var` as target")
+ self.target = target
+
+ def _predict(self, X: np.ndarray) -> np.ndarray:
+ # sample covariance
+ S = super()._predict(X)
+
+ # shrinking target
+ F = self._get_shrink_target(X, S)
+
+ # get shrinking parameter
+ alpha = self._get_shrink_param(X, S, F)
+
+ # shrink covariance
+ if alpha > 0:
+ S *= 1 - alpha
+ F *= alpha
+ S += F
+
+ return S
+
+ def _get_shrink_target(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:
+ """get shrinking target `F`"""
+ if self.target == self.TGT_CONST_VAR:
+ return self._get_shrink_target_const_var(X, S)
+ if self.target == self.TGT_CONST_CORR:
+ return self._get_shrink_target_const_corr(X, S)
+ if self.target == self.TGT_SINGLE_FACTOR:
+ return self._get_shrink_target_single_factor(X, S)
+ return self.target
+
+ def _get_shrink_target_const_var(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:
+ """get shrinking target with constant variance
+
+ This target assumes zero pair-wise correlation and constant variance.
+ The constant variance is estimated by averaging all sample's variances.
+ """
+ n = len(S)
+ F = np.eye(n)
+ np.fill_diagonal(F, np.mean(np.diag(S)))
+ return F
+
+ def _get_shrink_target_const_corr(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:
+ """get shrinking target with constant correlation
+
+ This target assumes constant pair-wise correlation but keep the sample variance.
+ The constant correlation is estimated by averaging all pairwise correlations.
+ """
+ n = len(S)
+ var = np.diag(S)
+ sqrt_var = np.sqrt(var)
+ covar = np.outer(sqrt_var, sqrt_var)
+ r_bar = (np.sum(S / covar) - n) / (n * (n - 1))
+ F = r_bar * covar
+ np.fill_diagonal(F, var)
+ return F
+
+ def _get_shrink_target_single_factor(self, X: np.ndarray, S: np.ndarray) -> np.ndarray:
+ """get shrinking target with single factor model"""
+ X_mkt = np.nanmean(X, axis=1)
+ cov_mkt = np.asarray(X.T.dot(X_mkt) / len(X))
+ var_mkt = np.asarray(X_mkt.dot(X_mkt) / len(X))
+ F = np.outer(cov_mkt, cov_mkt) / var_mkt
+ np.fill_diagonal(F, np.diag(S))
+ return F
+
+ def _get_shrink_param(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:
+ """get shrinking parameter `alpha`
+
+ Note:
+ The Ledoit-Wolf shrinking parameter estimator consists of three different methods.
+ """
+ if self.alpha == self.SHR_OAS:
+ return self._get_shrink_param_oas(X, S, F)
+ elif self.alpha == self.SHR_LW:
+ if self.target == self.TGT_CONST_VAR:
+ return self._get_shrink_param_lw_const_var(X, S, F)
+ if self.target == self.TGT_CONST_CORR:
+ return self._get_shrink_param_lw_const_corr(X, S, F)
+ if self.target == self.TGT_SINGLE_FACTOR:
+ return self._get_shrink_param_lw_single_factor(X, S, F)
+ return self.alpha
+
+ def _get_shrink_param_oas(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:
+ """Oracle Approximating Shrinkage Estimator
+
+ This method uses the following formula to estimate the `alpha`
+ parameter for the shrink covariance estimator:
+ A = (1 - 2 / p) * trace(S^2) + trace^2(S)
+ B = (n + 1 - 2 / p) * (trace(S^2) - trace^2(S) / p)
+ alpha = A / B
+ where `n`, `p` are the dim of observations and variables respectively.
+ """
+ trS2 = np.sum(S ** 2)
+ tr2S = np.trace(S) ** 2
+
+ n, p = X.shape
+
+ A = (1 - 2 / p) * (trS2 + tr2S)
+ B = (n + 1 - 2 / p) * (trS2 + tr2S / p)
+ alpha = A / B
+
+ return alpha
+
+ def _get_shrink_param_lw_const_var(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:
+ """Ledoit-Wolf Shrinkage Estimator (Constant Variance)
+
+ This method shrinks the covariance matrix towards the constand variance target.
+ """
+ t, n = X.shape
+
+ y = X ** 2
+ phi = np.sum(y.T.dot(y) / t - S ** 2)
+
+ gamma = np.linalg.norm(S - F, "fro") ** 2
+
+ kappa = phi / gamma
+ alpha = max(0, min(1, kappa / t))
+
+ return alpha
+
+ def _get_shrink_param_lw_const_corr(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:
+ """Ledoit-Wolf Shrinkage Estimator (Constant Correlation)
+
+ This method shrinks the covariance matrix towards the constand correlation target.
+ """
+ t, n = X.shape
+
+ var = np.diag(S)
+ sqrt_var = np.sqrt(var)
+ r_bar = (np.sum(S / np.outer(sqrt_var, sqrt_var)) - n) / (n * (n - 1))
+
+ y = X ** 2
+ phi_mat = y.T.dot(y) / t - S ** 2
+ phi = np.sum(phi_mat)
+
+ theta_mat = (X ** 3).T.dot(X) / t - var[:, None] * S
+ np.fill_diagonal(theta_mat, 0)
+ rho = np.sum(np.diag(phi_mat)) + r_bar * np.sum(np.outer(1 / sqrt_var, sqrt_var) * theta_mat)
+
+ gamma = np.linalg.norm(S - F, "fro") ** 2
+
+ kappa = (phi - rho) / gamma
+ alpha = max(0, min(1, kappa / t))
+
+ return alpha
+
+ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float:
+ """Ledoit-Wolf Shrinkage Estimator (Single Factor Model)
+
+ This method shrinks the covariance matrix towards the single factor model target.
+ """
+ t, n = X.shape
+
+ X_mkt = np.nanmean(X, axis=1)
+ cov_mkt = np.asarray(X.T.dot(X_mkt) / len(X))
+ var_mkt = np.asarray(X_mkt.dot(X_mkt) / len(X))
+
+ y = X ** 2
+ phi = np.sum(y.T.dot(y)) / t - np.sum(S ** 2)
+
+ rdiag = np.sum(y ** 2) / t - np.sum(np.diag(S) ** 2)
+ z = X * X_mkt[:, None]
+ v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
+ roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
+ v3 = z.T.dot(z) / t - var_mkt * S
+ roff3 = (
+ np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt ** 2 - np.sum(np.diag(v3) * cov_mkt ** 2) / var_mkt ** 2
+ )
+ roff = 2 * roff1 - roff3
+ rho = rdiag + roff
+
+ gamma = np.linalg.norm(S - F, "fro") ** 2
+
+ kappa = (phi - rho) / gamma
+ alpha = max(0, min(1, kappa / t))
+
+ return alpha
+
+
+class POETCovEstimator(RiskModel):
+ """Principal Orthogonal Complement Thresholding Estimator (POET)
+
+ Reference:
+ [1] Fan, J., Liao, Y., & Mincheva, M. (2013). Large covariance estimation by thresholding principal orthogonal complements.
+ Journal of the Royal Statistical Society. Series B: Statistical Methodology, 75(4), 603–680. https://doi.org/10.1111/rssb.12016
+ [2] http://econweb.rutgers.edu/yl1114/papers/poet/POET.m
+ """
+
+ THRESH_SOFT = "soft"
+ THRESH_HARD = "hard"
+ THRESH_SCAD = "scad"
+
+ def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs):
+ """
+ Args:
+ num_factors (int): number of factors (if set to zero, no factor model will be used).
+ thresh (float): the positive constant for thresholding.
+ thresh_method (str): thresholding method, which can be
+ - 'soft': soft thresholding.
+ - 'hard': hard thresholding.
+ - 'scad': scad thresholding.
+ kwargs: see `RiskModel` for more information.
+ """
+ super().__init__(**kwargs)
+
+ assert num_factors >= 0, "`num_factors` requires a positive integer"
+ self.num_factors = num_factors
+
+ assert thresh >= 0, "`thresh` requires a positive float number"
+ self.thresh = thresh
+
+ assert thresh_method in [
+ self.THRESH_HARD,
+ self.THRESH_SOFT,
+ self.THRESH_SCAD,
+ ], "`thresh_method` should be `soft`/`hard`/`scad`"
+ self.thresh_method = thresh_method
+
+ def _predict(self, X: np.ndarray) -> np.ndarray:
+
+ Y = X.T # NOTE: to match POET's implementation
+ p, n = Y.shape
+
+ if self.num_factors > 0:
+ Dd, V = np.linalg.eig(Y.T.dot(Y))
+ V = V[:, np.argsort(Dd)]
+ F = V[:, -self.num_factors :][:, ::-1] * np.sqrt(n)
+ LamPCA = Y.dot(F) / n
+ uhat = np.asarray(Y - LamPCA.dot(F.T))
+ Lowrank = np.asarray(LamPCA.dot(LamPCA.T))
+ rate = 1 / np.sqrt(p) + np.sqrt(np.log(p) / n)
+ else:
+ uhat = np.asarray(Y)
+ rate = np.sqrt(np.log(p) / n)
+ Lowrank = 0
+
+ lamb = rate * self.thresh
+ SuPCA = uhat.dot(uhat.T) / n
+ SuDiag = np.diag(np.diag(SuPCA))
+ R = np.linalg.inv(SuDiag ** 0.5).dot(SuPCA).dot(np.linalg.inv(SuDiag ** 0.5))
+
+ if self.thresh_method == self.THRESH_HARD:
+ M = R * (np.abs(R) > lamb)
+ elif self.thresh_method == self.THRESH_SOFT:
+ res = np.abs(R) - lamb
+ res = (res + np.abs(res)) / 2
+ M = np.sign(R) * res
+ else:
+ M1 = (np.abs(R) < 2 * lamb) * np.sign(R) * (np.abs(R) - lamb) * (np.abs(R) > lamb)
+ M2 = (np.abs(R) < 3.7 * lamb) * (np.abs(R) >= 2 * lamb) * (2.7 * R - 3.7 * np.sign(R) * lamb) / 1.7
+ M3 = (np.abs(R) >= 3.7 * lamb) * R
+ M = M1 + M2 + M3
+
+ Rthresh = M - np.diag(np.diag(M)) + np.eye(p)
+ SigmaU = (SuDiag ** 0.5).dot(Rthresh).dot(SuDiag ** 0.5)
+ SigmaY = SigmaU + Lowrank
+
+ return SigmaY
diff --git a/qlib/model/task.py b/qlib/model/task.py
new file mode 100644
index 000000000..f29f513a4
--- /dev/null
+++ b/qlib/model/task.py
@@ -0,0 +1,27 @@
+import abc
+import typing
+
+
+class TaskGen(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def __call__(self, *args, **kwargs) -> typing.List[dict]:
+ """
+ generate
+
+ Parameters
+ ----------
+ args, kwargs:
+ The info for generating tasks
+ Example 1):
+ input: a specific task template
+ output: rolling version of the tasks
+ Example 2):
+ input: a specific task template
+ output: a set of tasks with different losses
+
+ Returns
+ -------
+ typing.List[dict]:
+ A list of tasks
+ """
+ pass
diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py
new file mode 100644
index 000000000..0ef062021
--- /dev/null
+++ b/qlib/model/trainer.py
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from qlib.utils import init_instance_by_config, flatten_dict
+from qlib.workflow import R
+from qlib.workflow.record_temp import SignalRecord
+
+
+def task_train(config: dict, experiment_name):
+ """
+ task based training
+
+ Parameters
+ ----------
+ config : dict
+ A dict describing the training process
+ """
+
+ # model initiaiton
+ model = init_instance_by_config(config.get("task")["model"])
+ dataset = init_instance_by_config(config.get("task")["dataset"])
+
+ # start exp
+ with R.start(experiment_name=experiment_name):
+ # train model
+ R.log_params(**flatten_dict(config.get("task")))
+ model.fit(dataset)
+ recorder = R.get_recorder()
+
+ # generate records: prediction, backtest, and analysis
+ for record in config.get("task")["record"]:
+ if record["class"] == SignalRecord.__name__:
+ srconf = {"model": model, "dataset": dataset, "recorder": recorder}
+ record["kwargs"].update(srconf)
+ sr = init_instance_by_config(record)
+ sr.generate()
+ else:
+ rconf = {"recorder": recorder}
+ record["kwargs"].update(rconf)
+ ar = init_instance_by_config(record)
+ ar.generate()
diff --git a/qlib/portfolio/__init__.py b/qlib/portfolio/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/qlib/portfolio/optimizer.py b/qlib/portfolio/optimizer.py
new file mode 100644
index 000000000..534a66e2d
--- /dev/null
+++ b/qlib/portfolio/optimizer.py
@@ -0,0 +1,258 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import warnings
+import numpy as np
+import pandas as pd
+import scipy.optimize as so
+
+from typing import Optional, Union, Callable, List
+
+
+class PortfolioOptimizer(object):
+ """Portfolio Optimizer
+
+ The following optimization algorithms are supported:
+ - `gmv`: Global Minimum Variance Portfolio
+ - `mvo`: Mean Variance Optimized Portfolio
+ - `rp`: Risk Parity
+ - `inv`: Inverse Volatility
+
+ Note:
+ This optimizer always assumes full investment and no-shorting.
+ """
+
+ OPT_GMV = "gmv"
+ OPT_MVO = "mvo"
+ OPT_RP = "rp"
+ OPT_INV = "inv"
+
+ def __init__(
+ self,
+ method: str = "inv",
+ lamb: float = 0,
+ delta: float = 0,
+ alpha: float = 0.0,
+ scale_alpha: bool = True,
+ tol: float = 1e-8,
+ ):
+ """
+ Args:
+ method (str): portfolio optimization method
+ lamb (float): risk aversion parameter (larger `lamb` means more focus on return)
+ delta (float): turnover rate limit
+ alpha (float): l2 norm regularizer
+ tol (float): tolerance for optimization termination
+ """
+ assert method in [self.OPT_GMV, self.OPT_MVO, self.OPT_RP, self.OPT_INV], f"method `{method}` is not supported"
+ self.method = method
+
+ assert lamb >= 0, f"risk aversion parameter `lamb` should be positive"
+ self.lamb = lamb
+
+ assert delta >= 0, f"turnover limit `delta` should be positive"
+ self.delta = delta
+
+ assert alpha >= 0, f"l2 norm regularizer `alpha` should be positive"
+ self.alpha = alpha
+
+ self.tol = tol
+
+ def __call__(
+ self,
+ S: Union[np.ndarray, pd.DataFrame],
+ u: Optional[Union[np.ndarray, pd.Series]] = None,
+ w0: Optional[Union[np.ndarray, pd.Series]] = None,
+ ) -> Union[np.ndarray, pd.Series]:
+ """
+ Args:
+ S (np.ndarray or pd.DataFrame): covariance matrix
+ u (np.ndarray or pd.Series): expected returns (a.k.a., alpha)
+ w0 (np.ndarray or pd.Series): initial weights (for turnover control)
+
+ Returns:
+ np.ndarray or pd.Series: optimized portfolio allocation
+ """
+ # transform dataframe into array
+ index = None
+ if isinstance(S, pd.DataFrame):
+ index = S.index
+ S = S.values
+
+ # transform alpha
+ if u is not None:
+ assert len(u) == len(S), "`u` has mismatched shape"
+ if isinstance(u, pd.Series):
+ assert all(u.index == index), "`u` has mismatched index"
+ u = u.values
+
+ # transform initial weights
+ if w0 is not None:
+ assert len(w0) == len(S), "`w0` has mismatched shape"
+ if isinstance(w0, pd.Series):
+ assert all(w0.index == index), "`w0` has mismatched index"
+ w0 = w0.values
+
+ # scale alpha to match volatility
+ if u is not None:
+ u = u / u.std()
+ u *= np.mean(np.diag(S)) ** 0.5
+
+ # optimize
+ w = self._optimize(S, u, w0)
+
+ # restore index if needed
+ if index is not None:
+ w = pd.Series(w, index=index)
+
+ return w
+
+ def _optimize(self, S: np.ndarray, u: Optional[np.ndarray] = None, w0: Optional[np.ndarray] = None) -> np.ndarray:
+
+ # inverse volatility
+ if self.method == self.OPT_INV:
+ if u is not None:
+ warnings.warn("`u` is set but will not be used for `inv` portfolio")
+ if w0 is not None:
+ warnings.warn("`w0` is set but will not be used for `inv` portfolio")
+ return self._optimize_inv(S)
+
+ # global minimum variance
+ if self.method == self.OPT_GMV:
+ if u is not None:
+ warnings.warn("`u` is set but will not be used for `gmv` portfolio")
+ return self._optimize_gmv(S, w0)
+
+ # mean-variance
+ if self.method == self.OPT_MVO:
+ return self._optimize_mvo(S, u, w0)
+
+ # risk parity
+ if self.method == self.OPT_RP:
+ if u is not None:
+ warnings.warn("`u` is set but will not be used for `rp` portfolio")
+ return self._optimize_rp(S, w0)
+
+ def _optimize_inv(self, S: np.ndarray) -> np.ndarray:
+ """Inverse volatility"""
+ vola = np.diag(S) ** 0.5
+ w = 1 / vola
+ w /= w.sum()
+ return w
+
+ def _optimize_gmv(self, S: np.ndarray, w0: Optional[np.ndarray] = None) -> np.ndarray:
+ """optimize global minimum variance portfolio
+
+ This method solves the following optimization problem
+ min_w w' S w
+ s.t. w >= 0, sum(w) == 1
+ where `S` is the covariance matrix.
+ """
+ return self._solve(len(S), self._get_objective_gmv(S), *self._get_constrains(w0))
+
+ def _optimize_mvo(
+ self, S: np.ndarray, u: Optional[np.ndarray] = None, w0: Optional[np.ndarray] = None
+ ) -> np.ndarray:
+ """optimize mean-variance portfolio
+
+ This method solves the following optimization problem
+ min_w - w' u + lamb * w' S w
+ s.t. w >= 0, sum(w) == 1
+ where `S` is the covariance matrix, `u` is the expected returns,
+ and `lamb` is the risk aversion parameter.
+ """
+ return self._solve(len(S), self._get_objective_mvo(S, u), *self._get_constrains(w0))
+
+ def _optimize_rp(self, S: np.ndarray, w0: Optional[np.ndarray] = None) -> np.ndarray:
+ """optimize risk parity portfolio
+
+ This method solves the following optimization problem
+ min_w sum_i [w_i - (w' S w) / ((S w)_i * N)]**2
+ s.t. w >= 0, sum(w) == 1
+ where `S` is the covariance matrix and `N` is the number of stocks.
+ """
+ return self._solve(len(S), self._get_objective_rp(S), *self._get_constrains(w0))
+
+ def _get_objective_gmv(self, S: np.ndarray) -> np.ndarray:
+ """global minimum variance optimization objective
+
+ Optimization objective
+ min_w w' S w
+ """
+
+ def func(x):
+ return x @ S @ x
+
+ return func
+
+ def _get_objective_mvo(self, S: np.ndarray, u: np.ndarray = None) -> np.ndarray:
+ """mean-variance optimization objective
+
+ Optimization objective
+ min_w - w' u + lamb * w' S w
+ """
+
+ def func(x):
+ risk = x @ S @ x
+ ret = x @ u
+ return -ret + self.lamb * risk
+
+ return func
+
+ def _get_objective_rp(self, S: np.ndarray) -> np.ndarray:
+ """risk-parity optimization objective
+
+ Optimization objective
+ min_w sum_i [w_i - (w' S w) / ((S w)_i * N)]**2
+ """
+
+ def func(x):
+ N = len(x)
+ Sx = S @ x
+ xSx = x @ Sx
+ return np.sum((x - xSx / Sx / N) ** 2)
+
+ return func
+
+ def _get_constrains(self, w0: Optional[np.ndarray] = None):
+ """optimization constraints
+
+ Defines the following constraints:
+ - no shorting and leverage: 0 <= w <= 1
+ - full investment: sum(w) == 1
+ - turnover constraint: |w - w0| <= delta
+ """
+
+ # no shorting and leverage
+ bounds = so.Bounds(0.0, 1.0)
+
+ # full investment constraint
+ cons = [{"type": "eq", "fun": lambda x: np.sum(x) - 1}] # == 0
+
+ # turnover constraint
+ if w0 is not None:
+ cons.append({"type": "ineq", "fun": lambda x: self.delta - np.sum(np.abs(x - w0))}) # >= 0
+
+ return bounds, cons
+
+ def _solve(self, n: int, obj: Callable, bounds: so.Bounds, cons: List) -> np.ndarray:
+ """solve optimization
+
+ Args:
+ n (int): number of parameters
+ obj (callable): optimization objective
+ bounds (Bounds): bounds of parameters
+ cons (list): optimization constraints
+ """
+ # add l2 regularization
+ wrapped_obj = obj
+ if self.alpha > 0:
+ wrapped_obj = lambda x: obj(x) + self.alpha * np.sum(np.square(x))
+
+ # solve
+ x0 = np.ones(n) / n # init results
+ sol = so.minimize(wrapped_obj, x0, bounds=bounds, constraints=cons, tol=self.tol)
+ if not sol.success:
+ warnings.warn(f"optimization not success ({sol.status})")
+
+ return sol.x
diff --git a/qlib/utils.py b/qlib/utils/__init__.py
similarity index 75%
rename from qlib/utils.py
rename to qlib/utils/__init__.py
index f45b171de..5b313a0ef 100644
--- a/qlib/utils.py
+++ b/qlib/utils/__init__.py
@@ -20,12 +20,14 @@ import requests
import tempfile
import importlib
import contextlib
+import collections
import numpy as np
import pandas as pd
from pathlib import Path
+from typing import Union, Tuple
-from .config import C
-from .log import get_module_logger
+from ..config import C
+from ..log import get_module_logger
log = get_module_logger("utils")
@@ -43,7 +45,7 @@ def read_bin(file_path, start_index, end_index):
ref_start_index = int(np.frombuffer(f.read(4), dtype=" end_index:
- return pd.Series()
+ return pd.Series(dtype=np.float32)
# calculate offset
f.seek(4 * (si - ref_start_index) + 4)
# read nbytes
@@ -164,6 +166,77 @@ def get_module_by_module_path(module_path):
return module
+def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
+ """
+ extract class and kwargs from config info
+
+ Parameters
+ ----------
+ config : [dict, str]
+ similar to config
+
+ module : Python module
+ It should be a python module to load the class type
+
+ Returns
+ -------
+ (type, dict):
+ the class object and it's arguments.
+ """
+ if isinstance(config, dict):
+ # raise AttributeError
+ klass = getattr(module, config["class"])
+ kwargs = config.get("kwargs", {})
+ elif isinstance(config, str):
+ klass = getattr(module, config)
+ kwargs = {}
+ else:
+ raise NotImplementedError(f"This type of input is not supported")
+ return klass, kwargs
+
+
+def init_instance_by_config(
+ config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
+) -> object:
+ """
+ get initialized instance with config
+
+ Parameters
+ ----------
+ config : Union[str, dict, object]
+ dict example.
+ {
+ 'class': 'ClassName',
+ 'kwargs': dict, # It is optional. {} will be used if not given
+ 'model_path': path, # It is optional if module is given
+ }
+ str example.
+ "ClassName": getattr(module, config)() will be used.
+ object example:
+ instance of accept_types
+ module : Python module
+ Optional. It should be a python module.
+ NOTE: the "module_path" will be override by `module` arguments
+
+ accept_types: Union[type, Tuple[type]]
+ Optional. If the config is a instance of specific type, return the config directly.
+ This will be passed into the second parameter of isinstance.
+
+ Returns
+ -------
+ object:
+ An initialized object based on the config info
+ """
+ if isinstance(config, accept_types):
+ return config
+
+ if module is None:
+ module = get_module_by_module_path(config["module_path"])
+
+ klass, cls_kwargs = get_cls_kwargs(config, module)
+ return klass(**cls_kwargs, **kwargs)
+
+
def compare_dict_value(src_data: dict, dst_data: dict):
"""Compare dict value
@@ -377,7 +450,7 @@ def is_tradable_date(cur_date):
date : pandas.Timestamp
current date
"""
- from .data import D
+ from ..data import D
return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
@@ -390,7 +463,7 @@ def get_date_range(trading_date, shift, future=False):
:param future: bool
:return:
"""
- from .data import D
+ from ..data import D
calendar = D.calendar(future=future)
if pd.to_datetime(trading_date) not in list(calendar):
@@ -445,7 +518,7 @@ def transform_end_date(end_date=None, freq="day"):
date : pandas.Timestamp
current date
"""
- from .data import D
+ from ..data import D
last_date = D.calendar(freq=freq)[-1]
if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)):
@@ -540,8 +613,101 @@ def exists_qlib_data(qlib_dir):
# check instruments
code_names = set(map(lambda x: x.name.lower(), features_dir.iterdir()))
_instrument = instruments_dir.joinpath("all.txt")
- miss_code = set(pd.read_csv(_instrument, sep="\t", header=None).loc[:, 0].apply(str.lower)) - set(code_names)
+ df = pd.read_csv(_instrument, sep="\t", names=["inst", "start_datetime", "end_datetime", "save_inst"])
+ df = df.iloc[:, [0, -1]].fillna(axis=1, method="ffill")
+ miss_code = set(df.iloc[:, -1].apply(str.lower)) - set(code_names)
if miss_code and any(map(lambda x: "sht" not in x, miss_code)):
return False
return True
+
+
+def lexsort_index(df: pd.DataFrame) -> pd.DataFrame:
+ """
+ make the df index lexsorted
+
+ df.sort_index() will take a lot of time even when `df.is_lexsorted() == True`
+ This function could avoid such case
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+
+ Returns
+ -------
+ pd.DataFrame:
+ sorted dataframe
+ """
+ if df.index.is_lexsorted():
+ return df
+ else:
+ return df.sort_index()
+
+
+def flatten_dict(d, parent_key="", sep="."):
+ """flatten_dict.
+ >>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
+ >>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
+
+ Parameters
+ ----------
+ d :
+ d
+ parent_key :
+ parent_key
+ sep :
+ sep
+ """
+ items = []
+ for k, v in d.items():
+ new_key = parent_key + sep + k if parent_key else k
+ if isinstance(v, collections.abc.MutableMapping):
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+#################### Wrapper #####################
+class Wrapper(object):
+ """Wrapper class for anything that needs to set up during qlib.init"""
+
+ def __init__(self):
+ self._provider = None
+
+ def register(self, provider):
+ self._provider = provider
+
+ def __getattr__(self, key):
+ if self._provider is None:
+ raise AttributeError("Please run qlib.init() first using qlib")
+ return getattr(self._provider, key)
+
+
+def register_wrapper(wrapper, cls_or_obj, module_path=None):
+ """register_wrapper
+
+ :param wrapper: A wrapper.
+ :param cls_or_obj: A class or class name or object instance.
+ """
+ if isinstance(cls_or_obj, str):
+ module = get_module_by_module_path(module_path)
+ cls_or_obj = getattr(module, cls_or_obj)
+ obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj
+ wrapper.register(obj)
+
+
+def load_dataset(path_or_obj):
+ """load dataset from multiple file formats"""
+ if isinstance(path_or_obj, pd.DataFrame):
+ return path_or_obj
+ if not os.path.exists(path_or_obj):
+ raise ValueError(f"file {path_or_obj} doesn't exist")
+ _, extension = os.path.splitext(path_or_obj)
+ if extension == ".h5":
+ return pd.read_hdf(path_or_obj)
+ elif extension == ".pkl":
+ return pd.read_pickle(path_or_obj)
+ elif extension == ".csv":
+ return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
+ raise ValueError(f"unsupported file type `{extension}`")
diff --git a/qlib/utils/objm.py b/qlib/utils/objm.py
new file mode 100644
index 000000000..eebd529c6
--- /dev/null
+++ b/qlib/utils/objm.py
@@ -0,0 +1,131 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+import os
+import pickle
+import tempfile
+from pathlib import Path
+
+from qlib.config import C
+
+
+class ObjManager:
+ def save_obj(self, obj: object, name: str):
+ """
+ save obj as name
+
+ Parameters
+ ----------
+ obj : object
+ object to be saved
+ name : str
+ name of the object
+ """
+ raise NotImplementedError(f"Please implement `save_obj`")
+
+ def save_objs(self, obj_name_l):
+ """
+ save objects
+
+ Parameters
+ ----------
+ obj_name_l : list of
+ """
+ raise NotImplementedError(f"Please implement the `save_objs` method")
+
+ def load_obj(self, name: str) -> object:
+ """
+ load object by name
+
+ Parameters
+ ----------
+ name : str
+ the name of the object
+
+ Returns
+ -------
+ object:
+ loaded object
+ """
+ raise NotImplementedError(f"Please implement the `load_obj` method")
+
+ def exists(self, name: str) -> bool:
+ """
+ if the object named `name` exists
+
+ Parameters
+ ----------
+ name : str
+ name of the objecT
+
+ Returns
+ -------
+ bool:
+ If the object exists
+ """
+ raise NotImplementedError(f"Please implement the `exists` method")
+
+ def list(self) -> list:
+ """
+ list the objects
+
+ Returns
+ -------
+ list:
+ the list of returned objects
+ """
+ raise NotImplementedError(f"Please implement the `list` method")
+
+ def remove(self, fname=None):
+ """remove.
+
+ Parameters
+ ----------
+ fname :
+ if file name is provided. specific file is removed
+ otherwise, The all the objects will be removed.
+ """
+ raise NotImplementedError(f"Please implement the `remove` method")
+
+
+class FileManager(ObjManager):
+ """
+ Use file system to manage objects
+ """
+
+ def __init__(self, path=None):
+ if path is None:
+ self.path = Path(self.create_path())
+ else:
+ self.path = Path(path).resolve()
+
+ def create_path(self) -> str:
+ try:
+ return tempfile.mkdtemp(prefix=str(C["file_manager_path"]) + os.sep)
+ except AttributeError:
+ raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented")
+
+ def save_obj(self, obj, name):
+ with (self.path / name).open("wb") as f:
+ pickle.dump(obj, f)
+
+ def save_objs(self, obj_name_l):
+ for obj, name in obj_name_l:
+ self.save_obj(obj, name)
+
+ def load_obj(self, name):
+ with (self.path / name).open("rb") as f:
+ return pickle.load(f)
+
+ def exists(self, name):
+ return (self.path / name).exists()
+
+ def list(self):
+ return list(self.path.iterdir())
+
+ def remove(self, fname=None):
+ if fname is None:
+ for fp in self.path.glob("*"):
+ fp.unlink()
+ self.path.rmdir()
+ else:
+ (self.path / fname).unlink()
diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py
new file mode 100644
index 000000000..a640b04ea
--- /dev/null
+++ b/qlib/utils/paral.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from joblib import Parallel, delayed
+import pandas as pd
+
+
+def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False):
+ """datetime_groupby_apply
+ This function will apply the `apply_func` on the datetime level index.
+
+ Parameters
+ ----------
+ df :
+ DataFrame for processing
+ apply_func :
+ apply_func for processing the data
+ axis :
+ which axis is the datetime level located
+ level :
+ which level is the datetime level
+ resample_rule :
+ How to resample the data to calculating parallel
+ n_jobs :
+ n_jobs for joblib
+ Returns:
+ pd.DataFrame
+ """
+
+ def _naive_group_apply(df):
+ return df.groupby(axis=axis, level=level).apply(apply_func)
+
+ if n_jobs != 1:
+ dfs = Parallel(n_jobs=n_jobs)(
+ delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, axis=axis, level=level)
+ )
+ return pd.concat(dfs, axis=axis).sort_index()
+ else:
+ return _naive_group_apply(df)
diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py
new file mode 100644
index 000000000..2d22434ac
--- /dev/null
+++ b/qlib/utils/serial.py
@@ -0,0 +1,60 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from pathlib import Path
+import pickle
+
+
+class Serializable:
+ """
+ Serializable behaves like pickle.
+ But it only saves the state whose name **does not** start with `_`
+ """
+
+ def __init__(self):
+ self._dump_all = False
+ self._exclude = []
+
+ def __getstate__(self) -> dict:
+ return {
+ k: v for k, v in self.__dict__.items() if k not in self.exclude and (self.dump_all or not k.startswith("_"))
+ }
+
+ def __setstate__(self, state: dict):
+ self.__dict__.update(state)
+
+ @property
+ def dump_all(self):
+ """
+ will the object dump all object
+
+ Parameters
+ ----------
+ self : [TODO:type]
+ [TODO:description]
+ """
+ return getattr(self, "_dump_all", False)
+
+ @property
+ def exclude(self):
+ """
+ What attribute will be dumped
+
+ Parameters
+ ----------
+ self : [TODO:type]
+ [TODO:description]
+ """
+ return getattr(self, "_exclude", [])
+
+ def config(self, dump_all: bool = None, exclude: list = None):
+ if dump_all is not None:
+ self._dump_all = dump_all
+
+ if exclude is not None:
+ self._exclude = exclude
+
+ def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
+ self.config(dump_all=dump_all, exclude=exclude)
+ with Path(path).open("wb") as f:
+ pickle.dump(self, f)
diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py
new file mode 100644
index 000000000..c0745f6d4
--- /dev/null
+++ b/qlib/workflow/__init__.py
@@ -0,0 +1,460 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from contextlib import contextmanager
+from .expm import MLflowExpManager
+from .recorder import Recorder
+from ..utils import Wrapper
+
+
+class QlibRecorder:
+ """
+ A global system that helps to manage the experiments.
+ """
+
+ def __init__(self, exp_manager):
+ self.exp_manager = exp_manager
+
+ @contextmanager
+ def start(self, experiment_name=None, recorder_name=None):
+ """
+ Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code:
+
+ .. code-block:: Python
+
+ with R.start('test', 'recorder_1'):
+ model.fit(dataset)
+ R.log...
+ ... # further operations
+
+ Parameters
+ ----------
+ experiment_name : str
+ name of the experiment one wants to start.
+ recorder_name : str
+ name of the recorder under the experiment one wants to start.
+ """
+ run = self.start_exp(experiment_name, recorder_name)
+ try:
+ yield run
+ except Exception as e:
+ self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong
+ raise e
+ self.end_exp(Recorder.STATUS_FI)
+
+ def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
+ """
+ Lower level method for starting an experiment. When use this method, one should end the experiment manually
+ and the status of the recorder may not be handled properly. Here is the example code:
+
+ .. code-block:: Python
+
+ R.start_exp(experiment_name='test', recorder_name='recorder_1')
+ ... # further operations
+ R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
+
+
+ Parameters
+ ----------
+ experiment_name : str
+ the name of the experiment to be started
+ recorder_name : str
+ name of the recorder under the experiment one wants to start.
+ uri : str
+ the tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
+ The default uri are set in the qlib.config.
+
+ Returns
+ -------
+ An experiment instance being started.
+ """
+ return self.exp_manager.start_exp(experiment_name, recorder_name, uri)
+
+ def end_exp(self, recorder_status=Recorder.STATUS_FI):
+ """
+ Method for ending an experiment manually. It will end the current active experiment, as well as its
+ active recorder with the specified `status` type. Here is the example code of the method:
+
+ .. code-block:: Python
+
+ R.start_exp(experiment_name='test')
+ ... # further operations
+ R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
+
+ Parameters
+ ----------
+ status : str
+ The status of a recorder, which can be SCHEDULED, RUNNING, FINISHED, FAILED.
+ """
+ self.exp_manager.end_exp(recorder_status)
+
+ def search_records(self, experiment_ids, **kwargs):
+ """
+ Get a pandas DataFrame of records that fit the search criteria. Here is the example code of the method:
+
+ .. code-block:: Python
+
+ R.log_metrics(m=2.50, step=0)
+ records = R.search_runs([experiment_id], order_by=["metrics.m DESC"])
+
+ Parameters
+ ----------
+ experiment_ids : list
+ list of experiment IDs.
+ filter_string : str
+ filter query string, defaults to searching all runs.
+ run_view_type : int
+ one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType).
+ max_results : int
+ the maximum number of runs to put in the dataframe.
+ order_by : list
+ list of columns to order by (e.g., “metrics.rmse”).
+
+ Returns
+ -------
+ A pandas.DataFrame of records, where each metric, parameter, and tag
+ are expanded into their own columns named metrics.*, params.*, and tags.*
+ respectively. For records that don't have a particular metric, parameter, or tag, their
+ value will be (NumPy) Nan, None, or None respectively.
+ """
+ return self.exp_manager.search_records(experiment_ids, **kwargs)
+
+ def list_experiments(self):
+ """
+ Method for listing all the existing experiments (except for those being deleted.)
+
+ .. code-block:: Python
+
+ exps = R.list_experiments()
+
+ Returns
+ -------
+ A dictionary (name -> experiment) of experiments information that being stored.
+ """
+ return self.exp_manager.list_experiments()
+
+ def list_recorders(self, experiment_id=None, experiment_name=None):
+ """
+ Method for listing all the recorders of experiment with given id or name.
+
+ If user doesn't provide the id or name of the experiment, this method will try to retrieve the default experiment and
+ list all the recorders of the default experiment. If the default experiment doesn't exist, the method will first
+ create the default experiment, and then create a new recorder under it.
+
+ Here is the example code:
+
+ .. code-block:: Python
+
+ recorders = R.list_recorders(experiment_name='test')
+
+ Parameters
+ ----------
+ experiment_id : str
+ id of the experiment.
+ experiment_name : str
+ name of the experiment.
+
+ Returns
+ -------
+ A dictionary (id -> recorder) of recorder information that being stored.
+ """
+ return self.get_exp(experiment_id, experiment_name).list_recorders()
+
+ def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
+ """
+ Method for retrieving an experiment with given id or name. Once the `create` argument is set to
+ True, if no valid experiment is found, this method will create one for you. Otherwise, it will
+ only retrieve a specific experiment or raise an Error.
+
+ - If '`create`' is True:
+
+ - If ``R``'s running:
+
+ - no id or name specified, return the active experiment.
+
+ - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be running.
+
+ - If ``R``'s not running:
+
+ - no id or name specified, create a default experiment, and the experiment is set to be running.
+
+ - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be running.
+
+ - Else If '`create`' is False:
+
+ - If ``R``'s running:
+
+ - no id or name specified, return the active experiment.
+
+ - if id or name is specified, return the specified experiment. If no such exp found, raise Error.
+
+ - If ``R``'s not running:
+
+ - no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
+
+ - if id or name is specified, return the specified experiment. If no such exp found, raise Error.
+
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ exp = R.get_exp()
+ recorders = exp.list_recorders()
+
+ # Case 2
+ with R.start('test'):
+ exp = R.get_exp('test1')
+
+ # Case 3
+ exp = R.get_exp() -> a default experiment.
+
+ # Case 4
+ exp = R.get_exp(experiment_name='test')
+
+ # Case 5
+ exp = R.get_exp(create=False) -> the default experiment if exists.
+
+ Parameters
+ ----------
+ experiment_id : str
+ id of the experiment.
+ experiment_name : str
+ name of the experiment.
+ create : boolean
+ an argument determines whether the method will automatically create a new experiment
+ according to user's specification if the experiment hasn't been created before.
+
+ Returns
+ -------
+ An experiment instance with given id or name.
+ """
+ return self.exp_manager.get_exp(experiment_id, experiment_name, create)
+
+ def delete_exp(self, experiment_id=None, experiment_name=None):
+ """
+ Method for deleting the experiment with given id or name. At least one of id or name must be given,
+ otherwise, error will occur.
+
+ Here is the example code:
+
+ .. code-block:: Python
+
+ R.delete_exp(experiment_name='test')
+
+ Parameters
+ ----------
+ experiment_id : str
+ id of the experiment.
+ experiment_name : str
+ name of the experiment.
+ """
+ self.exp_manager.delete_exp(experiment_id, experiment_name)
+
+ def get_uri(self):
+ """
+ Method for retrieving the uri of current experiment manager.
+
+ Here is the example code:
+
+ .. code-block:: Python
+
+ uri = R.get_uri()
+
+ Returns
+ -------
+ The uri of current experiment manager.
+ """
+ return self.exp_manager.get_uri()
+
+ def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
+ """
+ Method for retrieving a recorder.
+
+ - If ``R``'s running:
+
+ - no id or name specified, return the active recorder.
+
+ - if id or name is specified, return the specified recorder.
+
+ - If ``R``'s not running:
+
+ - no id or name specified, raise Error.
+
+ - if id or name is specified, and the corresponding experiment_name must be given, return the specified recorder. Otherwise, raise Error.
+
+ The recorder can be used for further process such as `save_object`, `load_object`, `log_params`,
+ `log_metrics`, etc.
+
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ recorder = R.get_recorder()
+
+ # Case 2
+ with R.start('test'):
+ recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
+
+ # Case 3
+ recorder = R.get_recorder() -> Error
+
+ # Case 4
+ recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d') -> Error
+
+ # Case 5
+ recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')
+
+ Parameters
+ ----------
+ recorder_id : str
+ id of the recorder.
+ recorder_name : str
+ name of the recorder.
+ experiment_name : str
+ name of the experiment.
+
+ Returns
+ -------
+ A recorder instance.
+ """
+ return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
+ recorder_id, recorder_name, create=False
+ )
+
+ def delete_recorder(self, recorder_id=None, recorder_name=None):
+ """
+ Method for deleting the recorders with given id or name. At least one of id or name must be given,
+ otherwise, error will occur.
+
+ Here is the example code:
+
+ .. code-block:: Python
+
+ R.delete_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
+
+ Parameters
+ ----------
+ recorder_id : str
+ id of the experiment.
+ recorder_name : str
+ name of the experiment.
+ """
+ self.get_exp().delete_recorder(recorder_id, recorder_name)
+
+ def save_objects(self, local_path=None, artifact_path=None, **kwargs):
+ """
+ Method for saving objects as artifacts in the experiment to the uri. It supports either saving
+ from a local file/directory, or directly saving objects. User can use valid python's keywords arguments
+ to specify the object to be saved as well as its name (name: value).
+
+ - If R's running: it will save the objects through the running recorder.
+ - If R's not running: the system will create a default experiment, and a new recorder and save objects under it.
+
+ .. note::
+
+ If one wants to save objects with a specific recorder. It is recommended to first get the specific recorder through `get_recorder` API and use the recorder the save objects. The supported arguments are the same as this method.
+
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ pred = model.predict(dataset)
+ R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
+
+ # Case 2
+ with R.start('test'):
+ R.save_objects(local_path='results/pred.pkl')
+
+ Parameters
+ ----------
+ local_path : str
+ if provided, them save the file or directory to the artifact URI.
+ artifact_path : str
+ the relative path for the artifact to be stored in the URI.
+ """
+ self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
+
+ def log_params(self, **kwargs):
+ """
+ Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
+
+ - If R's running: it will log parameters through the running recorder.
+ - If R's not running: the system will create a default experiment as well as a new recorder, and log parameters under it.
+
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ R.log_params(learning_rate=0.01)
+
+ # Case 2
+ R.log_params(learning_rate=0.01)
+
+ Parameters
+ ----------
+ keyword argument:
+ name1=value1, name2=value2, ...
+ """
+ self.get_exp().get_recorder().log_params(**kwargs)
+
+ def log_metrics(self, step=None, **kwargs):
+ """
+ Method for logging metrics during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
+
+ - If R's running: it will log metrics through the running recorder.
+ - If R's not running: the system will create a default experiment as well as a new recorder, and log metrics under it.
+
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ R.log_metrics(train_loss=0.33, step=1)
+
+ # Case 2
+ R.log_metrics(train_loss=0.33, step=1)
+
+ Parameters
+ ----------
+ keyword argument:
+ name1=value1, name2=value2, ...
+ """
+ self.get_exp().get_recorder().log_metrics(step, **kwargs)
+
+ def set_tags(self, **kwargs):
+ """
+ Method for setting tags for a recorder. In addition to using ``R``, one can also set the tag to a specific recorder after getting it with `get_recorder` API.
+
+ - If R's running: it will set tags through the running recorder.
+ - If R's not running: the system will create a default experiment as well as a new recorder, and set the tags under it.
+
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ R.set_tags(release_version="2.2.0")
+
+ # Case 2
+ R.set_tags(release_version="2.2.0")
+
+ Parameters
+ ----------
+ keyword argument:
+ name1=value1, name2=value2, ...
+ """
+ self.get_exp().get_recorder().set_tags(**kwargs)
+
+
+# global record
+R = Wrapper()
diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py
new file mode 100644
index 000000000..65d9a14b4
--- /dev/null
+++ b/qlib/workflow/cli.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import sys, os
+from pathlib import Path
+
+import qlib
+import fire
+import pandas as pd
+import ruamel.yaml as yaml
+from qlib.config import C
+from qlib.model.trainer import task_train
+
+
+def get_path_list(path):
+ if isinstance(path, str):
+ return [path]
+ else:
+ return [p for p in path]
+
+
+def sys_config(config, config_path):
+ """
+ Configure the `sys` section
+
+ Parameters
+ ----------
+ config : dict
+ configuration of the workflow.
+ config_path : str
+ path of the configuration
+ """
+ sys_config = config.get("sys", {})
+
+ # abspath
+ for p in get_path_list(sys_config.get("path", [])):
+ sys.path.append(p)
+
+ # relative path to config path
+ for p in get_path_list(sys_config.get("rel_path", [])):
+ sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
+
+
+# worflow handler function
+def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
+ with open(config_path) as fp:
+ config = yaml.load(fp, Loader=yaml.Loader)
+
+ # config the `sys` section
+ sys_config(config, config_path)
+
+ provider_uri = config.get("provider_uri")
+ region = config.get("region")
+ exp_manager = C["exp_manager"]
+ exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder)
+ qlib.init(provider_uri=provider_uri, region=region, exp_manager=exp_manager)
+
+ task_train(config, experiment_name=experiment_name)
+
+
+# function to run worklflow by config
+def run():
+ fire.Fire(workflow)
+
+
+if __name__ == "__main__":
+ run()
diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py
new file mode 100644
index 000000000..09c680e59
--- /dev/null
+++ b/qlib/workflow/exp.py
@@ -0,0 +1,289 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import mlflow
+from mlflow.entities import ViewType
+from mlflow.exceptions import MlflowException
+from pathlib import Path
+from .recorder import Recorder, MLflowRecorder
+from ..log import get_module_logger
+
+logger = get_module_logger("workflow", "INFO")
+
+
+class Experiment:
+ """
+ This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow.
+ (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
+ """
+
+ def __init__(self, id, name):
+ self.id = id
+ self.name = name
+ self.active_recorder = None # only one recorder can running each time
+
+ def __repr__(self):
+ return str(self.info)
+
+ def __str__(self):
+ return str(self.info)
+
+ @property
+ def info(self):
+ recorders = self.list_recorders()
+ output = dict()
+ output["class"] = "Experiment"
+ output["id"] = self.id
+ output["name"] = self.name
+ output["active_recorder"] = self.active_recorder.id if self.active_recorder is not None else None
+ output["recorders"] = list(recorders.keys())
+ return output
+
+ def start(self, recorder_name=None):
+ """
+ Start the experiment and set it to be active. This method will also start a new recorder.
+
+ Parameters
+ ----------
+ recorder_name : str
+ the name of the recorder to be created.
+
+ Returns
+ -------
+ An active recorder.
+ """
+ raise NotImplementedError(f"Please implement the `start` method.")
+
+ def end(self, recorder_status=Recorder.STATUS_S):
+ """
+ End the experiment.
+
+ Parameters
+ ----------
+ recorder_status : str
+ the status the recorder to be set with when ending (SCHEDULED, RUNNING, FINISHED, FAILED).
+ """
+ raise NotImplementedError(f"Please implement the `end` method.")
+
+ def create_recorder(self, name=None):
+ """
+ Create a recorder for each experiment.
+
+ Parameters
+ ----------
+ name : str
+ the name of the recorder to be created.
+
+ Returns
+ -------
+ A recorder object.
+ """
+ raise NotImplementedError(f"Please implement the `create_recorder` method.")
+
+ def search_records(self, **kwargs):
+ """
+ Get a pandas DataFrame of records that fit the search criteria of the experiment.
+ Inputs are the search critera user want to apply.
+
+ Returns
+ -------
+ A pandas.DataFrame of records, where each metric, parameter, and tag
+ are expanded into their own columns named metrics.*, params.*, and tags.*
+ respectively. For records that don't have a particular metric, parameter, or tag, their
+ value will be (NumPy) Nan, None, or None respectively.
+ """
+ raise NotImplementedError(f"Please implement the `search_records` method.")
+
+ def delete_recorder(self, recorder_id):
+ """
+ Create a recorder for each experiment.
+
+ Parameters
+ ----------
+ recorder_id : str
+ the id of the recorder to be deleted.
+ """
+ raise NotImplementedError(f"Please implement the `delete_recorder` method.")
+
+ def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True):
+ """
+ Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the
+ specific recorder. When user does not provide recorder id or name, the method will try to return the current
+ active recorder. The `create` argument determines whether the method will automatically create a new recorder
+ according to user's specification if the recorder hasn't been created before
+
+ * If `create` is True:
+
+ * If R's running:
+
+ * no id or name specified, return the active recorder.
+ * if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name, and the recorder shoud be running.
+
+ * If R's not running:
+
+ * no id or name specified, create a new recorder.
+ * if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name, and the recorder shoud be running.
+
+ * Else If `create` is False:
+
+ * If R's running:
+
+ * no id or name specified, return the active recorder.
+ * if id or name is specified, return the specified recorder. If no such exp found, raise Error.
+
+ * If R's not running:
+
+ * no id or name specified, raise Error.
+ * if id or name is specified, return the specified recorder. If no such exp found, raise Error.
+
+ Parameters
+ ----------
+ recorder_id : str
+ the id of the recorder to be deleted.
+ recorder_name : str
+ the name of the recorder to be deleted.
+ create : boolean
+ create the recorder if it hasn't been created before.
+
+ Returns
+ -------
+ A recorder object.
+ """
+ raise NotImplementedError(f"Please implement the `get_recorder` method.")
+
+ def list_recorders(self):
+ """
+ List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
+ If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
+
+ Returns
+ -------
+ A dictionary (id -> recorder) of recorder information that being stored.
+ """
+ raise NotImplementedError(f"Please implement the `list_recorders` method.")
+
+
+class MLflowExperiment(Experiment):
+ """
+ Use mlflow to implement Experiment.
+ """
+
+ def __init__(self, id, name, uri):
+ super(MLflowExperiment, self).__init__(id, name)
+ self._uri = uri
+ self._default_name = None
+ self._default_rec_name = "mlflow_recorder"
+ self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
+
+ def start(self, recorder_name=None):
+ # set the active experiment
+ mlflow.set_experiment(self.name)
+ logger.info(f"Experiment {self.id} starts running ...")
+ # set up recorder
+ recorder = self.create_recorder(recorder_name)
+ self.active_recorder = recorder
+ # start the recorder
+ self.active_recorder.start_run()
+
+ return self.active_recorder
+
+ def end(self, recorder_status):
+ if self.active_recorder is not None:
+ self.active_recorder.end_run(recorder_status)
+ self.active_recorder = None
+
+ def create_recorder(self, recorder_name=None):
+ if recorder_name is None:
+ recorder_name = self._default_rec_name
+ recorder = MLflowRecorder(self.id, self._uri, recorder_name)
+
+ return recorder
+
+ def get_recorder(self, recorder_id=None, recorder_name=None, create=True):
+ # special case of getting the recorder
+ if recorder_id is None and recorder_name is None:
+ if self.active_recorder is not None:
+ return self.active_recorder
+ recorder_name = self._default_rec_name
+ if create:
+ recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
+ else:
+ recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
+ if is_new:
+ mlflow.set_experiment(self.name)
+ self.active_recorder = recorder
+ # start the recorder
+ self.active_recorder.start_run()
+ return recorder
+
+ def _get_or_create_rec(self, recorder_id=None, recorder_name=None) -> (object, bool):
+ """
+ Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
+ automatically create a new recorder based on the given id and name.
+ """
+ try:
+ return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
+ except ValueError:
+ if recorder_name is None:
+ recorder_name = self._default_rec_name
+ logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.")
+ return self.create_recorder(recorder_name), True
+
+ def _get_recorder(self, recorder_id=None, recorder_name=None):
+ """
+ Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
+ raise errors.
+ """
+ assert (
+ recorder_id is not None or recorder_name is not None
+ ), "Please input at least one of recorder id or name before retrieving recorder."
+ if recorder_id is not None:
+ try:
+ run = self.client.get_run(recorder_id)
+ recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
+ return recorder
+ except MlflowException:
+ raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.")
+ elif recorder_name is not None:
+ logger.warning(
+ f"Please make sure the recorder name {recorder_name} is unique, we will only return the first recorder if there exist several matched the given name."
+ )
+ recorders = self.list_recorders()
+ for rid in recorders:
+ if recorders[rid].name == recorder_name:
+ return recorders[rid]
+ raise ValueError("No valid recorder has been found, please make sure the input recorder name is correct.")
+
+ def search_records(self, **kwargs):
+ filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
+ run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
+ max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
+ order_by = kwargs.get("order_by")
+
+ return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
+
+ def delete_recorder(self, recorder_id=None, recorder_name=None):
+ assert (
+ recorder_id is not None or recorder_name is not None
+ ), "Please input a valid recorder id or name before deleting."
+ try:
+ if recorder_id is not None:
+ self.client.delete_run(recorder_id)
+ else:
+ recorder = self._get_recorder(recorder_name=recorder_name)
+ self.client.delete_run(recorder.id)
+ except MlflowException as e:
+ raise Exception(
+ f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
+ )
+
+ UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
+
+ def list_recorders(self, max_results=UNLIMITED):
+ runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
+ recorders = dict()
+ for i in range(len(runs)):
+ recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
+ recorders[runs[i].info.run_id] = recorder
+
+ return recorders
diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py
new file mode 100644
index 000000000..cfb0290fc
--- /dev/null
+++ b/qlib/workflow/expm.py
@@ -0,0 +1,333 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import mlflow
+from mlflow.exceptions import MlflowException
+from mlflow.entities import ViewType
+import os
+from pathlib import Path
+from contextlib import contextmanager
+from .exp import MLflowExperiment, Experiment
+from .recorder import Recorder, MLflowRecorder
+from ..log import get_module_logger
+
+logger = get_module_logger("workflow", "INFO")
+
+
+class ExpManager:
+ """
+ This is the `ExpManager` class for managing experiments. The API is designed similar to mlflow.
+ (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
+ """
+
+ def __init__(self, uri, default_exp_name):
+ self.uri = uri
+ self.default_exp_name = default_exp_name
+ self.active_experiment = None # only one experiment can running each time
+
+ def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
+ """
+ Start an experiment. This method includes first get_or_create an experiment, and then
+ set it to be running.
+
+ Parameters
+ ----------
+ experiment_name : str
+ name of the active experiment.
+ recorder_name : str
+ name of the recorder to be started.
+ uri : str
+ the current tracking URI.
+
+ Returns
+ -------
+ An active experiment.
+ """
+ raise NotImplementedError(f"Please implement the `start_exp` method.")
+
+ def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
+ """
+ End an running experiment.
+
+ Parameters
+ ----------
+ experiment_name : str
+ name of the active experiment.
+ recorder_status : str
+ the status of the active recorder of the experiment.
+ """
+ raise NotImplementedError(f"Please implement the `end_exp` method.")
+
+ def create_exp(self, experiment_name=None):
+ """
+ Create an experiment.
+
+ Parameters
+ ----------
+ experiment_name : str
+ the experiment name, which must be unique.
+
+ Returns
+ -------
+ An experiment object.
+ """
+ raise NotImplementedError(f"Please implement the `create_exp` method.")
+
+ def search_records(self, experiment_ids=None, **kwargs):
+ """
+ Get a pandas DataFrame of records that fit the search criteria of the experiment.
+ Inputs are the search critera user want to apply.
+
+ Returns
+ -------
+ A pandas.DataFrame of records, where each metric, parameter, and tag
+ are expanded into their own columns named metrics.*, params.*, and tags.*
+ respectively. For records that don't have a particular metric, parameter, or tag, their
+ value will be (NumPy) Nan, None, or None respectively.
+ """
+ raise NotImplementedError(f"Please implement the `search_records` method.")
+
+ def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
+ """
+ Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
+ The returned experiment will be running.
+
+ When user specify experiment id and name, the method will try to return the specific experiment.
+ When user does not provide recorder id or name, the method will try to return the current active experiment.
+ The `create` argument determines whether the method will automatically create a new experiment according
+ to user's specification if the experiment hasn't been created before.
+
+ * If `create` is True:
+
+ * If R's running:
+
+ * no id or name specified, return the active experiment.
+ * if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be running.
+
+ * If R's not running:
+
+ * no id or name specified, create a default experiment.
+ * if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be running.
+
+ * Else If `create` is False:
+
+ * If R's running:
+
+ * no id or name specified, return the active experiment.
+ * if id or name is specified, return the specified experiment. If no such exp found, raise Error.
+
+ * If R's not running:
+
+ * no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
+ * if id or name is specified, return the specified experiment. If no such exp found, raise Error.
+
+ Parameters
+ ----------
+ experiment_id : str
+ id of the experiment to return.
+ experiment_name : str
+ name of the experiment to return.
+ create : boolean
+ create the experiment it if hasn't been created before.
+
+ Returns
+ -------
+ An experiment object.
+ """
+ # special case of getting experiment
+ if experiment_id is None and experiment_name is None:
+ if self.active_experiment is not None:
+ return self.active_experiment
+ # User don't want get active code now.
+ # Don't assume underlying code could handle the case of two None
+ if experiment_id is None and experiment_name is None:
+ experiment_name = self.default_exp_name
+
+ if create:
+ exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
+ else:
+ exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
+ if is_new:
+ self.active_experiment = exp
+ # start the recorder
+ self.active_experiment.start()
+ return exp
+
+ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (object, bool):
+ """
+ Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will
+ automatically create a new experiment based on the given id and name.
+ """
+ try:
+ if experiment_id is None and experiment_name is None:
+ experiment_name = self.default_exp_name
+ return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
+ except ValueError:
+ if experiment_name is None:
+ experiment_name = self.default_exp_name
+ logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
+ return self.create_exp(experiment_name), True
+
+ def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
+ """
+ get specific experiment by name or id. If it does not exist, raise ValueError
+
+ Parameters
+ ----------
+ experiment_id :
+ The id of experiment
+ experiment_name :
+ The id name experiment
+
+ Returns
+ -------
+ Experiment:
+ The searched experiment
+
+ Raises
+ ------
+ ValueError
+ """
+ raise NotImplementedError(f"Please implement the `_get_exp` method")
+
+ def delete_exp(self, experiment_id=None, experiment_name=None):
+ """
+ Delete an experiment.
+
+ Parameters
+ ----------
+ experiment_id : str
+ the experiment id.
+ experiment_name : str
+ the experiment name.
+ """
+ raise NotImplementedError(f"Please implement the `delete_exp` method.")
+
+ def get_uri(self):
+ """
+ Get the default tracking URI or current URI.
+
+ Returns
+ -------
+ The tracking URI string.
+ """
+ return self.uri
+
+ def list_experiments(self):
+ """
+ List all the existing experiments.
+
+ Returns
+ -------
+ A dictionary (name -> experiment) of experiments information that being stored.
+ """
+ raise NotImplementedError(f"Please implement the `list_experiments` method.")
+
+
+class MLflowExpManager(ExpManager):
+ """
+ Use mlflow to implement ExpManager.
+ """
+
+ def __init__(self, uri, default_exp_name):
+ super(MLflowExpManager, self).__init__(uri, default_exp_name)
+
+ @property
+ def client(self):
+ # Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
+ if not hasattr(self, "_client"):
+ self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
+ return self._client
+
+ def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
+ # set the tracking uri
+ if uri is None:
+ logger.info("No tracking URI is provided. Use the default tracking URI.")
+ else:
+ self.uri = uri
+ # create experiment
+ experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
+ # set up active experiment
+ self.active_experiment = experiment
+ # start the experiment
+ self.active_experiment.start(recorder_name)
+
+ return self.active_experiment
+
+ def end_exp(self, recorder_status: str = Recorder.STATUS_S):
+ if self.active_experiment is not None:
+ self.active_experiment.end(recorder_status)
+ self.active_experiment = None
+
+ def create_exp(self, experiment_name=None):
+ assert experiment_name is not None
+ # init experiment
+ experiment_id = self.client.create_experiment(experiment_name)
+ experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
+ experiment._default_name = self.default_exp_name
+
+ return experiment
+
+ def _get_exp(self, experiment_id=None, experiment_name=None):
+ """
+ Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will
+ raise errors.
+ """
+ assert (
+ experiment_id is not None or experiment_name is not None
+ ), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
+ if experiment_id is not None:
+ try:
+ exp = self.client.get_experiment(experiment_id)
+ if exp.lifecycle_stage.upper() == "DELETED":
+ raise MlflowException("No valid experiment has been found.")
+ experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
+ return experiment
+ except MlflowException:
+ raise ValueError(
+ "No valid experiment has been found, please make sure the input experiment id is correct."
+ )
+ elif experiment_name is not None:
+ try:
+ exp = self.client.get_experiment_by_name(experiment_name)
+ if exp is None or exp.lifecycle_stage.upper() == "DELETED":
+ raise MlflowException("No valid experiment has been found.")
+ experiment = MLflowExperiment(exp.experiment_id, experiment_name, self.uri)
+ return experiment
+ except MlflowException as e:
+ raise ValueError(
+ "No valid experiment has been found, please make sure the input experiment name is correct."
+ )
+
+ def search_records(self, experiment_ids, **kwargs):
+ filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
+ run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
+ max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
+ order_by = kwargs.get("order_by")
+ return self.client.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by)
+
+ def delete_exp(self, experiment_id=None, experiment_name=None):
+ assert (
+ experiment_id is not None or experiment_name is not None
+ ), "Please input a valid experiment id or name before deleting."
+ try:
+ if experiment_id is not None:
+ self.client.delete_experiment(experiment_id)
+ else:
+ experiment = self.client.get_experiment_by_name(experiment_name)
+ if experiment is None:
+ raise MlflowException("No valid experiment has been found.")
+ self.client.delete_experiment(experiment.experiment_id)
+ except MlflowException as e:
+ raise Exception(
+ f"Error: {e}. Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct."
+ )
+
+ def list_experiments(self):
+ # retrieve all the existing experiments
+ exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY)
+ experiments = dict()
+ for exp in exps:
+ experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
+ experiments[exp.name] = experiment
+ return experiments
diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py
new file mode 100644
index 000000000..ec76343bd
--- /dev/null
+++ b/qlib/workflow/record_temp.py
@@ -0,0 +1,262 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import re
+import pandas as pd
+from pathlib import Path
+from pprint import pprint
+from ..contrib.evaluate import (
+ backtest as normal_backtest,
+ risk_analysis,
+)
+from ..data.dataset import DatasetH
+from ..data.dataset.handler import DataHandlerLP
+from ..utils import init_instance_by_config, get_module_by_module_path
+from ..log import get_module_logger
+from ..utils import flatten_dict
+from ..contrib.eva.alpha import calc_ic, calc_long_short_return
+from ..contrib.strategy.strategy import BaseStrategy
+
+logger = get_module_logger("workflow", "INFO")
+
+
+class RecordTemp:
+ """
+ This is the Records Template class that enables user to generate experiment results such as IC and
+ backtest in a certain format.
+ """
+
+ artifact_path = None
+
+ @classmethod
+ def get_path(cls, path=None):
+ names = []
+ if cls.artifact_path is not None:
+ names.append(cls.artifact_path)
+
+ if path is not None:
+ names.append(path)
+
+ return "/".join(names)
+
+ def __init__(self, recorder):
+ self.recorder = recorder
+
+ def generate(self, **kwargs):
+ """
+ Generate certain records such as IC, backtest etc., and save them.
+
+ Parameters
+ ----------
+ kwargs
+
+ Return
+ ------
+ """
+ raise NotImplementedError(f"Please implement the `generate` method.")
+
+ def load(self, name):
+ """
+ Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
+ with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
+ in the future::
+
+ sar = SigAnaRecord(recorder)
+ ic = sar.load(sar.get_path("ic.pkl"))
+
+ Parameters
+ ----------
+ name : str
+ the name for the file to be load.
+
+ Return
+ ------
+ The stored records.
+ """
+ # try to load the saved object
+ obj = self.recorder.load_object(name)
+ return obj
+
+ def list(self):
+ """
+ List the stored records.
+
+ Return
+ ------
+ A list of all the stored records.
+ """
+ return []
+
+ def check(self, parent=False):
+ """
+ Check if the records is properly generated and saved.
+
+ Raise
+ ------
+ FileExistsError: whether the records are stored properly.
+ """
+ artifacts = set(self.recorder.list_artifacts())
+ if parent:
+ # Downcasting have to be done here instead of using `super`
+ flist = self.__class__.__base__.list(self) # pylint: disable=E1101
+ else:
+ flist = self.list()
+ for item in flist:
+ if item not in artifacts:
+ raise FileExistsError(item)
+
+
+class SignalRecord(RecordTemp):
+ """
+ This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
+ """
+
+ def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
+ super().__init__(recorder=recorder)
+ self.model = model
+ self.dataset = dataset
+
+ def generate(self, **kwargs):
+ # generate prediciton
+ pred = self.model.predict(self.dataset)
+ if isinstance(pred, pd.Series):
+ pred = pred.to_frame("score")
+ self.recorder.save_objects(**{"pred.pkl": pred})
+
+ logger.info(
+ f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
+ )
+ # print out results
+ pprint(f"The following are prediction results of the {type(self.model).__name__} model.")
+ pprint(pred.head(5))
+
+ # save according label
+ if isinstance(self.dataset, DatasetH):
+ params = dict(self=self.dataset, segments="test", col_set="label", data_key=DataHandlerLP.DK_R)
+ try:
+ # Assume the backend handler is DataHandlerLP
+ raw_label = DatasetH.prepare(**params)
+ except TypeError:
+ # The argument number is not right
+ del params["data_key"]
+ # The backend handler should be DataHandler
+ raw_label = DatasetH.prepare(**params)
+ self.recorder.save_objects(**{"label.pkl": raw_label})
+
+ def list(self):
+ return ["pred.pkl", "label.pkl"]
+
+ def load(self, name="pred.pkl"):
+ return super().load(name)
+
+
+class SigAnaRecord(SignalRecord):
+ """
+ This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
+ """
+
+ artifact_path = "sig_analysis"
+
+ def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
+ self.ana_long_short = ana_long_short
+ self.ann_scaler = ann_scaler
+ super().__init__(recorder=recorder, **kwargs)
+ # The name must be unique. Otherwise it will be overridden
+
+ def generate(self):
+ self.check(parent=True)
+
+ pred = self.load("pred.pkl")
+ label = self.load("label.pkl")
+ ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
+ metrics = {
+ "IC": ic.mean(),
+ "ICIR": ic.mean() / ic.std(),
+ "Rank IC": ric.mean(),
+ "Rank ICIR": ric.mean() / ric.std(),
+ }
+ objects = {"ic.pkl": ic, "ric.pkl": ric}
+ if self.ana_long_short:
+ long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
+ metrics.update(
+ {
+ "Long-Short Ann Return": long_short_r.mean() * self.ann_scaler,
+ "Long-Short Ann Sharpe": long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5,
+ "Long-Avg Ann Return": long_avg_r.mean() * self.ann_scaler,
+ "Long-Avg Ann Sharpe": long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5,
+ }
+ )
+ objects.update(
+ {
+ "long_short_r.pkl": long_short_r,
+ "long_avg_r.pkl": long_avg_r,
+ }
+ )
+ self.recorder.log_metrics(**metrics)
+ self.recorder.save_objects(**objects, artifact_path=self.get_path())
+ pprint(metrics)
+
+ def list(self):
+ paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
+ if self.ana_long_short:
+ paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
+ return paths
+
+
+class PortAnaRecord(SignalRecord):
+ """
+ This is the Portfolio Analysis Record class that generates the analysis results such as those of backtest. This class inherits the ``RecordTemp`` class.
+ """
+
+ artifact_path = "portfolio_analysis"
+
+ def __init__(self, recorder, config, **kwargs):
+ """
+ config["strategy"] : dict
+ define the strategy class as well as the kwargs.
+ config["backtest"] : dict
+ define the backtest kwargs.
+ """
+ super().__init__(recorder=recorder)
+
+ self.strategy_config = config["strategy"]
+ self.backtest_config = config["backtest"]
+ self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)
+
+ def generate(self, **kwargs):
+ # check previously stored prediction results
+ self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
+
+ # custom strategy and get backtest
+ pred_score = super().load()
+ report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
+ self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
+ self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
+
+ # analysis
+ analysis = dict()
+ analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
+ analysis["excess_return_with_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"] - report_normal["cost"]
+ )
+ # save portfolio analysis results
+ analysis_df = pd.concat(analysis) # type: pd.DataFrame
+ # log metrics
+ self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
+ # save results
+ self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path())
+ logger.info(
+ f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
+ )
+ # print out results
+ pprint("The following are analysis results of the excess return without cost.")
+ pprint(analysis["excess_return_without_cost"])
+ pprint("The following are analysis results of the excess return with cost.")
+ pprint(analysis["excess_return_with_cost"])
+
+ def list(self):
+ return [
+ PortAnaRecord.get_path("report_normal.pkl"),
+ PortAnaRecord.get_path("positions_normal.pkl"),
+ PortAnaRecord.get_path("port_analysis.pkl"),
+ ]
diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py
new file mode 100644
index 000000000..4c1ddfdfe
--- /dev/null
+++ b/qlib/workflow/recorder.py
@@ -0,0 +1,307 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import mlflow
+import shutil, os, pickle, tempfile, codecs
+from pathlib import Path
+from datetime import datetime
+from ..utils.objm import FileManager
+from ..log import get_module_logger
+
+logger = get_module_logger("workflow", "INFO")
+
+
+class Recorder:
+ """
+ This is the `Recorder` class for logging the experiments. The API is designed similar to mlflow.
+ (The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
+
+ The status of the recorder can be SCHEDULED, RUNNING, FINISHED, FAILED.
+ """
+
+ # status type
+ STATUS_S = "SCHEDULED"
+ STATUS_R = "RUNNING"
+ STATUS_FI = "FINISHED"
+ STATUS_FA = "FAILED"
+
+ def __init__(self, experiment_id, name):
+ self.id = None
+ self.name = name
+ self.experiment_id = experiment_id
+ self.start_time = None
+ self.end_time = None
+ self.status = Recorder.STATUS_S
+
+ def __repr__(self):
+ return str(self.info)
+
+ def __str__(self):
+ return str(self.info)
+
+ @property
+ def info(self):
+ output = dict()
+ output["class"] = "Recorder"
+ output["id"] = self.id
+ output["name"] = self.name
+ output["experiment_id"] = self.experiment_id
+ output["start_time"] = self.start_time
+ output["end_time"] = self.end_time
+ output["status"] = self.status
+ return output
+
+ def set_recorder_name(self, rname):
+ self.recorder_name = rname
+
+ def save_objects(self, local_path=None, artifact_path=None, **kwargs):
+ """
+ Save objects such as prediction file or model checkpoints to the artifact URI. User
+ can save object through keywords arguments (name:value).
+
+ Parameters
+ ----------
+ local_path : str
+ if provided, them save the file or directory to the artifact URI.
+ artifact_path=None : str
+ the relative path for the artifact to be stored in the URI.
+ """
+ raise NotImplementedError(f"Please implement the `save_objects` method.")
+
+ def load_object(self, name):
+ """
+ Load objects such as prediction file or model checkpoints.
+
+ Parameters
+ ----------
+ name : str
+ name of the file to be loaded.
+
+ Returns
+ -------
+ The saved object.
+ """
+ raise NotImplementedError(f"Please implement the `load_object` method.")
+
+ def start_run(self):
+ """
+ Start running or resuming the Recorder. The return value can be used as a context manager within a `with` block;
+ otherwise, you must call end_run() to terminate the current run. (See `ActiveRun` class in mlflow)
+
+ Returns
+ -------
+ An active running object (e.g. mlflow.ActiveRun object).
+ """
+ raise NotImplementedError(f"Please implement the `start_run` method.")
+
+ def end_run(self):
+ """
+ End an active Recorder.
+ """
+ raise NotImplementedError(f"Please implement the `end_run` method.")
+
+ def log_params(self, **kwargs):
+ """
+ Log a batch of params for the current run.
+
+ Parameters
+ ----------
+ keyword arguments
+ key, value pair to be logged as parameters.
+ """
+ raise NotImplementedError(f"Please implement the `log_params` method.")
+
+ def log_metrics(self, step=None, **kwargs):
+ """
+ Log multiple metrics for the current run.
+
+ Parameters
+ ----------
+ keyword arguments
+ key, value pair to be logged as metrics.
+ """
+ raise NotImplementedError(f"Please implement the `log_metrics` method.")
+
+ def set_tags(self, **kwargs):
+ """
+ Log a batch of tags for the current run.
+
+ Parameters
+ ----------
+ keyword arguments
+ key, value pair to be logged as tags.
+ """
+ raise NotImplementedError(f"Please implement the `set_tags` method.")
+
+ def delete_tags(self, *keys):
+ """
+ Delete some tags from a run.
+
+ Parameters
+ ----------
+ keys : series of strs of the keys
+ all the name of the tag to be deleted.
+ """
+ raise NotImplementedError(f"Please implement the `delete_tags` method.")
+
+ def list_artifacts(self, artifact_path: str = None):
+ """
+ List all the artifacts of a recorder.
+
+ Parameters
+ ----------
+ artifact_path : str
+ the relative path for the artifact to be stored in the URI.
+
+ Returns
+ -------
+ A list of artifacts information (name, path, etc.) that being stored.
+ """
+ raise NotImplementedError(f"Please implement the `list_artifacts` method.")
+
+ def list_metrics(self):
+ """
+ List all the metrics of a recorder.
+
+ Returns
+ -------
+ A dictionary of metrics that being stored.
+ """
+ raise NotImplementedError(f"Please implement the `list_metrics` method.")
+
+ def list_params(self):
+ """
+ List all the params of a recorder.
+
+ Returns
+ -------
+ A dictionary of params that being stored.
+ """
+ raise NotImplementedError(f"Please implement the `list_params` method.")
+
+ def list_tags(self):
+ """
+ List all the tags of a recorder.
+
+ Returns
+ -------
+ A dictionary of tags that being stored.
+ """
+ raise NotImplementedError(f"Please implement the `list_tags` method.")
+
+
+class MLflowRecorder(Recorder):
+ """
+ Use mlflow to implement a Recorder.
+
+ Due to the fact that mlflow will only log artifact from a file or directory, we decide to
+ use file manager to help maintain the objects in the project.
+ """
+
+ def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
+ super(MLflowRecorder, self).__init__(experiment_id, name)
+ self._uri = uri
+ self.artifact_uri = None
+ # set up file manager for saving objects
+ self.temp_dir = tempfile.mkdtemp()
+ self.fm = FileManager(Path(self.temp_dir).absolute())
+ self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
+ # construct from mlflow run
+ if mlflow_run is not None:
+ assert isinstance(mlflow_run, mlflow.entities.run.Run), "Please input with a MLflow Run object."
+ self.name = mlflow_run.data.tags["mlflow.runName"]
+ self.id = mlflow_run.info.run_id
+ self.status = mlflow_run.info.status
+ self.start_time = (
+ datetime.fromtimestamp(float(mlflow_run.info.start_time) / 1000.0).strftime("%Y-%m-%d %H:%M:%S")
+ if mlflow_run.info.start_time is not None
+ else None
+ )
+ self.end_time = (
+ datetime.fromtimestamp(float(mlflow_run.info.end_time) / 1000.0).strftime("%Y-%m-%d %H:%M:%S")
+ if mlflow_run.info.end_time is not None
+ else None
+ )
+
+ def start_run(self):
+ # set the tracking uri
+ mlflow.set_tracking_uri(self._uri)
+ # start the run
+ run = mlflow.start_run(self.id, self.experiment_id, self.name)
+ # save the run id and artifact_uri
+ self.id = run.info.run_id
+ self.artifact_uri = run.info.artifact_uri
+ self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ self.status = Recorder.STATUS_R
+ logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
+
+ return run
+
+ def end_run(self, status: str = Recorder.STATUS_S):
+ assert status in [
+ Recorder.STATUS_S,
+ Recorder.STATUS_R,
+ Recorder.STATUS_FI,
+ Recorder.STATUS_FA,
+ ], f"The status type {status} is not supported."
+ mlflow.end_run(status)
+ self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ if self.status != Recorder.STATUS_S:
+ self.status = status
+ shutil.rmtree(self.temp_dir)
+
+ def save_objects(self, local_path=None, artifact_path=None, **kwargs):
+ assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
+ if local_path is not None:
+ self.client.log_artifacts(self.id, local_path, artifact_path)
+ else:
+ for name, data in kwargs.items():
+ self.fm.save_obj(data, name)
+ self.client.log_artifact(self.id, self.fm.path / name, artifact_path)
+
+ def load_object(self, name):
+ assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
+ path = self.client.download_artifacts(self.id, name)
+ with Path(path).open("rb") as f:
+ return pickle.load(f)
+
+ def log_params(self, **kwargs):
+ for name, data in kwargs.items():
+ self.client.log_param(self.id, name, data)
+
+ def log_metrics(self, step=None, **kwargs):
+ for name, data in kwargs.items():
+ self.client.log_metric(self.id, name, data, step=step)
+
+ def set_tags(self, **kwargs):
+ for name, data in kwargs.items():
+ self.client.set_tag(self.id, name, data)
+
+ def delete_tags(self, *keys):
+ for key in keys:
+ self.client.delete_tag(self.id, key)
+
+ def get_artifact_uri(self):
+ if self.artifact_uri is not None:
+ return self.artifact_uri
+ else:
+ raise Exception(
+ "Please make sure the recorder has been created and started properly before getting artifact uri."
+ )
+
+ def list_artifacts(self, artifact_path=None):
+ assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
+ artifacts = self.client.list_artifacts(self.id, artifact_path)
+ return [art.path for art in artifacts]
+
+ def list_metrics(self):
+ run = self.client.get_run(self.id)
+ return run.data.metrics
+
+ def list_params(self):
+ run = self.client.get_run(self.id)
+ return run.data.params
+
+ def list_tags(self):
+ run = self.client.get_run(self.id)
+ return run.data.tags
diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py
new file mode 100644
index 000000000..33d251dd8
--- /dev/null
+++ b/qlib/workflow/utils.py
@@ -0,0 +1,48 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import sys, traceback, signal, atexit
+from . import R
+from .recorder import Recorder
+from ..log import get_module_logger
+
+logger = get_module_logger("workflow", "INFO")
+
+
+# function to handle the experiment when unusual program ending occurs
+def experiment_exit_handler():
+ """
+ Method for handling the experiment when any unusual program ending occurs.
+ The `atexit` handler should be put in the last, since, as long as the program ends, it will be called.
+ Thus, if any exception or user interuption occurs beforehead, we should handle them first. Once `R` is
+ ended, another call of `R.end_exp` will not take effect.
+ """
+ signal.signal(signal.SIGINT, experiment_kill_signal_handler) # handle user keyboard interupt
+ sys.excepthook = experiment_exception_hook # handle uncaught exception
+ atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends
+
+
+def experiment_exception_hook(type, value, tb):
+ """
+ End an experiment with status to be "FAILED". This exception tries to catch those uncaught exception
+ and end the experiment automatically.
+
+ Parameters
+ type: Exception type
+ value: Exception's value
+ tb: Exception's traceback
+ """
+ logger.error(f"An exception has been raised[{type.__name__}: {value}].")
+
+ # Same as original format
+ traceback.print_tb(tb)
+ print(f"{type.__name__}: {value}")
+
+ R.end_exp(recorder_status=Recorder.STATUS_FA)
+
+
+def experiment_kill_signal_handler(signum, frame):
+ """
+ End an experiment when user kill the program through keyboard (CTRL+C, etc.).
+ """
+ R.end_exp(recorder_status=Recorder.STATUS_FA)
diff --git a/requirements.txt b/requirements.txt
index 165619920..638ce22f4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,3 +22,4 @@ scikit_learn==0.23.2
torch==1.6.0
tqdm==4.49.0
yahooquery==2.2.7
+mlflow==1.12.1
\ No newline at end of file
diff --git a/scripts/README.md b/scripts/README.md
index 98b01e0c3..99af4a457 100644
--- a/scripts/README.md
+++ b/scripts/README.md
@@ -20,7 +20,6 @@ python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
### Downlaod US Data
-> The US stock code contains 'PRN', and the directory cannot be created on Windows system: https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
```bash
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
@@ -44,6 +43,8 @@ python get_data.py qlib_data --help
### US data
+> Need to download data first: [Downlaod US Data](#Downlaod-US-Data)
+
```python
import qlib
from qlib.config import REG_US
@@ -53,6 +54,8 @@ qlib.init(provider_uri=provider_uri, region=REG_US)
### CN data
+> Need to download data first: [Download CN Data](#Download-CN-Data)
+
```python
import qlib
from qlib.config import REG_CN
diff --git a/scripts/check_dump_bin.py b/scripts/check_dump_bin.py
index 7c2ceccda..7c2e837af 100644
--- a/scripts/check_dump_bin.py
+++ b/scripts/check_dump_bin.py
@@ -108,9 +108,7 @@ class CheckBin:
return self.COMPARE_ERROR
def check(self):
- """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data
-
- """
+ """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data"""
logger.info("start check......")
error_list = []
diff --git a/scripts/data_collector/index.py b/scripts/data_collector/index.py
index c5f3854fd..300e6b625 100644
--- a/scripts/data_collector/index.py
+++ b/scripts/data_collector/index.py
@@ -24,6 +24,7 @@ class IndexBase:
INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD]
REMOVE = "remove"
ADD = "add"
+ INST_PREFIX = ""
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
"""
@@ -196,7 +197,11 @@ class IndexBase:
_tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns)
new_df = new_df.append(_tmp_df, sort=False)
- new_df.loc[:, instruments_columns].to_csv(
+ inst_df = new_df.loc[:, instruments_columns]
+ _inst_prefix = self.INST_PREFIX.strip()
+ if _inst_prefix:
+ inst_df["save_inst"] = inst_df[self.SYMBOL_FIELD_NAME].apply(lambda x: f"{_inst_prefix}{x}")
+ inst_df.to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
)
logger.info(f"parse {self.index_name.lower()} companies finished.")
diff --git a/scripts/data_collector/us_index/collector.py b/scripts/data_collector/us_index/collector.py
index ea1e974a0..0641437e0 100644
--- a/scripts/data_collector/us_index/collector.py
+++ b/scripts/data_collector/us_index/collector.py
@@ -33,6 +33,10 @@ WIKI_INDEX_NAME_MAP = {
class WIKIIndex(IndexBase):
+ # NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
+ # https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
+ INST_PREFIX = "_"
+
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
super(WIKIIndex, self).__init__(
index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py
index 855569642..2cf9f4c6a 100644
--- a/scripts/data_collector/utils.py
+++ b/scripts/data_collector/utils.py
@@ -184,9 +184,14 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
names=["symbol", "start_date", "end_date"],
)
_all_symbols += ins_df["symbol"].unique().tolist()
- _US_SYMBOLS = sorted(
- set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))
- )
+
+ def _format(s_):
+ s_ = s_.replace(".", "-")
+ s_ = s_.strip("$")
+ s_ = s_.strip("*")
+ return s_
+
+ _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))))
return _US_SYMBOLS
diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index 69c7f8f15..0d41251f1 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -44,6 +44,7 @@ class YahooCollector:
delay=0,
check_data_length: bool = False,
limit_nums: int = None,
+ show_1m_logging: bool = False,
):
"""
@@ -67,10 +68,13 @@ class YahooCollector:
check data length, by default False
limit_nums: int
using for debug, by default None
+ show_1m_logging: bool
+ show 1m logging, by default False; if True, there may be many warning logs
"""
self.save_dir = Path(save_dir).expanduser().resolve()
self.save_dir.mkdir(parents=True, exist_ok=True)
self._delay = delay
+ self._show_1m_logging = show_1m_logging
self.stock_list = sorted(set(self.get_stock_list()))
if limit_nums is not None:
try:
@@ -83,7 +87,7 @@ class YahooCollector:
self._interval = interval
self._check_small_data = check_data_length
self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
- self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME
+ self._end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
if self._interval == "1m":
self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME)
elif self._interval == "1d":
@@ -91,8 +95,12 @@ class YahooCollector:
else:
raise ValueError(f"interval error: {self._interval}")
+ # using for 1m
+ self._next_datetime = self.convert_datetime(self._start_datetime.date() + pd.Timedelta(days=1))
+ self._latest_datetime = self.convert_datetime(self._end_datetime.date())
+
self._start_datetime = self.convert_datetime(self._start_datetime)
- self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME))
+ self._end_datetime = self.convert_datetime(self._end_datetime)
@property
@abc.abstractmethod
@@ -100,20 +108,24 @@ class YahooCollector:
# daily, one year: 252 / 4
# us 1min, a week: 6.5 * 60 * 5
# cn 1min, a week: 4 * 60 * 5
- raise NotImplementedError("rewirte min_numbers_trading")
+ raise NotImplementedError("rewrite min_numbers_trading")
@abc.abstractmethod
def get_stock_list(self):
- raise NotImplementedError("rewirte get_stock_list")
+ raise NotImplementedError("rewrite get_stock_list")
@property
- @abc.abstractclassmethod
+ @abc.abstractmethod
def _timezone(self):
raise NotImplementedError("rewrite get_timezone")
- def convert_datetime(self, dt: pd.Timestamp):
- dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
- return pd.Timestamp(dt, tz=tzlocal(), unit="s")
+ def convert_datetime(self, dt: [pd.Timestamp, datetime.date, str]):
+ try:
+ dt = pd.Timestamp(dt, tz=self._timezone).timestamp()
+ dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
+ except ValueError as e:
+ pass
+ return dt
def _sleep(self):
time.sleep(self._delay)
@@ -136,7 +148,7 @@ class YahooCollector:
df["symbol"] = symbol
if stock_path.exists():
with stock_path.open("a") as fp:
- df.to_csv(fp, index=False, header=None)
+ df.to_csv(fp, index=False, header=False)
else:
with stock_path.open("w") as fp:
df.to_csv(fp, index=False)
@@ -155,34 +167,47 @@ class YahooCollector:
def _get_from_remote(self, symbol):
def _get_simple(start_, end_):
self._sleep()
+ error_msg = f"{symbol}-{self._interval}-{start_}-{end_}"
+
+ def _show_logging_func():
+ if self._interval == "1m" and self._show_1m_logging:
+ logger.warning(f"{error_msg}:{_resp}")
+
try:
_resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, start=start_, end=end_)
if isinstance(_resp, pd.DataFrame):
return _resp.reset_index()
+ elif isinstance(_resp, dict):
+ _temp_data = _resp.get(symbol, {})
+ if isinstance(_temp_data, str) or (
+ isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
+ ):
+ _show_logging_func()
else:
- logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}")
+ _show_logging_func()
except Exception as e:
- logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}")
+ logger.warning(f"{error_msg}:{e}")
_result = None
if self._interval == "1d":
_result = _get_simple(self._start_datetime, self._end_datetime)
elif self._interval == "1m":
- _start_date = self._start_datetime.date() + pd.Timedelta(days=1)
- _end_date = self._end_datetime.date()
- if _start_date >= _end_date:
+ if self._next_datetime >= self._latest_datetime:
_result = _get_simple(self._start_datetime, self._end_datetime)
else:
_res = []
def _get_multi(start_, end_):
_resp = _get_simple(start_, end_)
- if _resp is not None:
+ if _resp is not None and not _resp.empty:
_res.append(_resp)
- for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)):
+ for _s, _e in (
+ (self._start_datetime, self._next_datetime),
+ (self._latest_datetime, self._end_datetime),
+ ):
_get_multi(_s, _e)
- for _start in pd.date_range(_start_date, _end_date, closed="left"):
+ for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
_end = _start + pd.Timedelta(days=1)
self._sleep()
_get_multi(_start, _end)
@@ -472,6 +497,7 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
+ show_1m_logging=False,
):
"""download data from Internet
@@ -491,6 +517,9 @@ class Run:
check data length, by default False
limit_nums: int
using for debug, by default None
+ show_1m_logging: bool
+ show 1m logging, by default False; if True, there may be many warning logs
+
Examples
---------
# get daily data
@@ -510,6 +539,7 @@ class Run:
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
+ show_1m_logging=show_1m_logging,
).collector_data()
def normalize_data(self):
@@ -531,6 +561,7 @@ class Run:
interval="1d",
check_data_length=False,
limit_nums=None,
+ show_1m_logging=False,
):
"""download -> normalize
@@ -550,6 +581,9 @@ class Run:
check data length, by default False
limit_nums: int
using for debug, by default None
+ show_1m_logging: bool
+ show 1m logging, by default False; if True, there may be many warning logs
+
Examples
-------
python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
@@ -562,6 +596,7 @@ class Run:
interval=interval,
check_data_length=check_data_length,
limit_nums=limit_nums,
+ show_1m_logging=show_1m_logging,
)
self.normalize_data()
diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py
index 2e44c454e..bdc227029 100644
--- a/scripts/dump_bin.py
+++ b/scripts/dump_bin.py
@@ -27,6 +27,7 @@ class DumpDataBase:
HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S"
INSTRUMENTS_SEP = "\t"
INSTRUMENTS_FILE_NAME = "all.txt"
+ SAVE_INST_FIELD = "save_inst"
UPDATE_MODE = "update"
ALL_MODE = "all"
@@ -44,6 +45,7 @@ class DumpDataBase:
exclude_fields: str = "",
include_fields: str = "",
limit_nums: int = None,
+ inst_prefix: str = "",
):
"""
@@ -71,6 +73,9 @@ class DumpDataBase:
fields not dumped
limit_nums: int
Use when debugging, default None
+ inst_prefix: str
+ add a column to the instruments file and record the saved instrument name,
+ the US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix.
"""
csv_path = Path(csv_path).expanduser()
if isinstance(exclude_fields, str):
@@ -79,6 +84,7 @@ class DumpDataBase:
include_fields = include_fields.split(",")
self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)))
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
+ self._inst_prefix = inst_prefix.strip()
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
@@ -134,7 +140,7 @@ class DumpDataBase:
def _get_source_data(self, file_path: Path) -> pd.DataFrame:
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
- df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64)
+ df[self.date_field_name] = df[self.date_field_name].astype(str).astype(np.datetime64)
# df.drop_duplicates([self.date_field_name], inplace=True)
return df
@@ -160,12 +166,19 @@ class DumpDataBase:
)
def _read_instruments(self, instrument_path: Path) -> pd.DataFrame:
- return pd.read_csv(
+ df = pd.read_csv(
instrument_path,
sep=self.INSTRUMENTS_SEP,
- names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD],
+ names=[
+ self.symbol_field_name,
+ self.INSTRUMENTS_START_FIELD,
+ self.INSTRUMENTS_END_FIELD,
+ self.SAVE_INST_FIELD,
+ ],
)
+ return df
+
def save_calendars(self, calendars_data: list):
self._calendars_dir.mkdir(parents=True, exist_ok=True)
calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve())
@@ -176,7 +189,13 @@ class DumpDataBase:
self._instruments_dir.mkdir(parents=True, exist_ok=True)
instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve())
if isinstance(instruments_data, pd.DataFrame):
- instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]]
+ _df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]
+ if self._inst_prefix:
+ _df_fields.append(self.SAVE_INST_FIELD)
+ instruments_data[self.SAVE_INST_FIELD] = instruments_data[self.symbol_field_name].apply(
+ lambda x: f"{self._inst_prefix}{x}"
+ )
+ instruments_data = instruments_data.loc[:, _df_fields]
instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP)
else:
np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8")
@@ -234,6 +253,7 @@ class DumpDataBase:
logger.warning(f"{code} data is None or empty")
return
# features save dir
+ code = self._inst_prefix + code if self._inst_prefix else code
features_dir = self._features_dir.joinpath(code)
features_dir.mkdir(parents=True, exist_ok=True)
self._data_to_bin(df, calendar_list, features_dir)
@@ -262,7 +282,10 @@ class DumpDataAll(DumpDataBase):
_begin_time = self._format_datetime(_begin_time)
_end_time = self._format_datetime(_end_time)
symbol = self.get_symbol_from_file(file_path)
- date_range_list.append(f"{self.INSTRUMENTS_SEP.join((symbol.upper(), _begin_time, _end_time))}")
+ _inst_fields = [symbol.upper(), _begin_time, _end_time]
+ if self._inst_prefix:
+ _inst_fields.append(self._inst_prefix + symbol.upper())
+ date_range_list.append(f"{self.INSTRUMENTS_SEP.join(_inst_fields)}")
p_bar.update()
self._kwargs["all_datetime_set"] = all_datetime
self._kwargs["date_range_list"] = date_range_list
@@ -310,16 +333,18 @@ class DumpDataFix(DumpDataAll):
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
p_bar.update()
- self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index"))
+ _inst_df = pd.DataFrame.from_dict(self._old_instruments, orient="index")
+ _inst_df.index.names = [self.symbol_field_name]
+ self.save_instruments(_inst_df.reset_index())
logger.info("end of instruments dump.\n")
def dump(self):
self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt"))
# noinspection PyAttributeOutsideInit
- self._old_instruments = self._read_instruments(
- self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME)
- ).to_dict(
- orient="index"
+ self._old_instruments = (
+ self._read_instruments(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME))
+ .set_index([self.symbol_field_name])
+ .to_dict(orient="index")
) # type: dict
self._dump_instruments()
self._dump_features()
diff --git a/scripts/get_data.py b/scripts/get_data.py
index 661e31c5f..f4dba1474 100644
--- a/scripts/get_data.py
+++ b/scripts/get_data.py
@@ -55,7 +55,9 @@ class GetData:
for _file in tqdm(zp.namelist()):
zp.extract(_file, str(target_dir.resolve()))
- def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"):
+ def qlib_data(
+ self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"
+ ):
"""download cn qlib data from remote
Parameters
@@ -77,9 +79,6 @@ class GetData:
-------
"""
- # TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system
- if region.lower() == "us":
- logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system")
file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip"
self._download_data(file_name.lower(), target_dir)
diff --git a/setup.py b/setup.py
index 3a6237e5a..0696a766f 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@ from setuptools import find_packages, setup, Extension
NAME = "pyqlib"
DESCRIPTION = "A Quantitative-research Platform"
REQUIRES_PYTHON = ">=3.5.0"
-VERSION = "0.5.1.dev0"
+VERSION = "0.6.0.alpha"
# Detect Cython
try:
@@ -43,17 +43,21 @@ REQUIRED = [
"schedule>=0.6.0",
"cvxpy==1.0.21",
"hyperopt==0.1.1",
- "fire>=0.2.1",
+ "fire>=0.3.1",
"statsmodels",
"xlrd>=1.0.0",
- "plotly==3.5.0",
+ "plotly==4.12.0",
"matplotlib==3.1.3",
"tables>=3.6.1",
"pyyaml>=5.3.1",
+ "mlflow>=1.12.1",
"tqdm",
"loguru",
"lightgbm",
"tornado",
+ "joblib>=0.17.0",
+ "fire>=0.3.1",
+ "ruamel.yaml>=0.16.12",
]
# Numpy include
@@ -97,7 +101,7 @@ setup(
entry_points={
# 'console_scripts': ['mycli=mymodule:cli'],
"console_scripts": [
- "estimator=qlib.contrib.estimator.launcher:run",
+ "qrun=qlib.workflow.cli:run",
],
},
ext_modules=extensions,
diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py
index 886fb31f3..befd296b0 100644
--- a/tests/test_all_pipeline.py
+++ b/tests/test_all_pipeline.py
@@ -2,68 +2,95 @@
# Licensed under the MIT License.
import sys
+import shutil
import unittest
from pathlib import Path
import numpy as np
import pandas as pd
-from scipy.stats import pearsonr
import qlib
-from qlib.config import REG_CN
+from qlib.config import REG_CN, C
from qlib.utils import drop_nan_by_y_index
from qlib.contrib.model.gbdt import LGBModel
-from qlib.contrib.estimator.handler import Alpha158
+from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
-from qlib.utils import exists_qlib_data
+from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
+from qlib.workflow import R
+from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
-DATA_HANDLER_CONFIG = {
- "dropna_label": True,
- "start_date": "2008-01-01",
- "end_date": "2020-08-01",
- "market": "CSI300",
+market = "csi300"
+benchmark = "SH000300"
+
+###################################
+# train model
+###################################
+data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": market,
}
-MODEL_CONFIG = {
- "loss": "mse",
- "colsample_bytree": 0.8879,
- "learning_rate": 0.0421,
- "subsample": 0.8789,
- "lambda_l1": 205.6999,
- "lambda_l2": 580.9768,
- "max_depth": 8,
- "num_leaves": 210,
- "num_threads": 20,
+task = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ "kwargs": {
+ "loss": "mse",
+ "colsample_bytree": 0.8879,
+ "learning_rate": 0.0421,
+ "subsample": 0.8789,
+ "lambda_l1": 205.6999,
+ "lambda_l2": 580.9768,
+ "max_depth": 8,
+ "num_leaves": 210,
+ "num_threads": 20,
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ },
}
-TRAINER_CONFIG = {
- "train_start_date": "2008-01-01",
- "train_end_date": "2014-12-31",
- "validate_start_date": "2015-01-01",
- "validate_end_date": "2016-12-31",
- "test_start_date": "2017-01-01",
- "test_end_date": "2020-08-01",
-}
-
-STRATEGY_CONFIG = {
- "topk": 50,
- "n_drop": 5,
-}
-
-BACKTEST_CONFIG = {
- "verbose": False,
- "limit_threshold": 0.095,
- "account": 100000000,
- "benchmark": "SH000300",
- "deal_price": "close",
- "open_cost": 0.0005,
- "close_cost": 0.0015,
- "min_cost": 5,
+port_analysis_config = {
+ "strategy": {
+ "class": "TopkDropoutStrategy",
+ "module_path": "qlib.contrib.strategy.strategy",
+ "kwargs": {
+ "topk": 50,
+ "n_drop": 5,
+ },
+ },
+ "backtest": {
+ "verbose": False,
+ "limit_threshold": 0.095,
+ "account": 100000000,
+ "benchmark": benchmark,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ },
}
@@ -78,59 +105,53 @@ def train():
performance: dict
model performance
"""
- # get data
- x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(**DATA_HANDLER_CONFIG).get_split_data(
- **TRAINER_CONFIG
- )
- # train
- model = LGBModel(**MODEL_CONFIG)
- model.fit(x_train, y_train, x_validate, y_validate)
- _pred = model.predict(x_test)
- _pred = pd.DataFrame(_pred, index=x_test.index, columns=y_test.columns)
- pred_score = pd.DataFrame(index=_pred.index)
- pred_score["score"] = _pred.iloc(axis=1)[0]
+ # model initiaiton
+ model = init_instance_by_config(task["model"])
+ dataset = init_instance_by_config(task["dataset"])
- # get performance
- try:
- model_score = model.score(x_test, y_test)
- except NotImplementedError:
- model_score = None
- # Remove rows from x, y and w, which contain Nan in any columns in y_test.
- x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
- pred_test = model.predict(x_test)
- model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
+ # start exp
+ with R.start(experiment_name="workflow"):
+ R.log_params(**flatten_dict(task))
+ model.fit(dataset)
- return pred_score, {"model_score": model_score, "model_pearsonr": model_pearsonr}
+ # prediction
+ recorder = R.get_recorder()
+ rid = recorder.id
+ sr = SignalRecord(model, dataset, recorder)
+ sr.generate()
+ pred_score = sr.load()
+
+ # calculate ic and ric
+ sar = SigAnaRecord(recorder)
+ sar.generate()
+ ic = sar.load(sar.get_path("ic.pkl"))
+ ric = sar.load(sar.get_path("ric.pkl"))
+
+ return pred_score, {"ic": ic, "ric": ric}, rid
-def backtest(pred):
- """backtest
+def backtest_analysis(pred, rid):
+ """backtest and analysis
Parameters
----------
- pred: pandas.DataFrame
+ pred : pandas.DataFrame
predict scores
+ rid : str
+ the id of the recorder to be used in this function
Returns
-------
- report_normal: pandas.DataFrame
-
- positions_normal: dict
+ analysis : pandas.DataFrame
+ the analysis result
"""
- strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
- _report_normal, _positions_normal = normal_backtest(pred, strategy=strategy, **BACKTEST_CONFIG)
- return _report_normal, _positions_normal
-
-
-def analyze(report_normal):
- _analysis = dict()
- _analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- _analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(_analysis) # type: pd.DataFrame
+ recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
+ # backtest
+ par = PortAnaRecord(recorder, port_analysis_config)
+ par.generate()
+ analysis_df = par.load(par.get_path("port_analysis.pkl"))
print(analysis_df)
return analysis_df
@@ -139,6 +160,7 @@ class TestAllFlow(unittest.TestCase):
PRED_SCORE = None
REPORT_NORMAL = None
POSITIONS = None
+ RID = None
@classmethod
def setUpClass(cls) -> None:
@@ -149,16 +171,22 @@ class TestAllFlow(unittest.TestCase):
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData
- GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri)
+ GetData().qlib_data(
+ name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri
+ )
qlib.init(provider_uri=provider_uri, region=REG_CN)
+ @classmethod
+ def tearDownClass(cls) -> None:
+ shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
+
def test_0_train(self):
- TestAllFlow.PRED_SCORE, model_pearsonr = train()
- self.assertGreaterEqual(model_pearsonr["model_pearsonr"], 0, "train failed")
+ TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
+ self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
+ self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
def test_1_backtest(self):
- TestAllFlow.REPORT_NORMAL, TestAllFlow.POSITIONS = backtest(TestAllFlow.PRED_SCORE)
- analyze_df = analyze(TestAllFlow.REPORT_NORMAL)
+ analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
self.assertGreaterEqual(
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
0.10,
diff --git a/tests/test_dump_data.py b/tests/test_dump_data.py
index 01e6a3758..dfa7f8556 100644
--- a/tests/test_dump_data.py
+++ b/tests/test_dump_data.py
@@ -75,7 +75,9 @@ class TestDumpData(unittest.TestCase):
def test_4_dump_features_simple(self):
stock = self.STOCK_NAMES[0]
- dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS)
+ dump_data = DumpDataFix(
+ csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
+ )
dump_data.dump()
df = D.features([stock], self.QLIB_FIELDS)
]