diff --git a/README.md b/README.md
index 4383dea26..c890afaca 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
-
+
@@ -28,6 +28,8 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
- [Quant Model Zoo](#quant-model-zoo)
+ - [Run a single model](#run-a-single-model)
+ - [Run multiple models](#run-multiple-models)
- [Quant Dataset Zoo](#quant-dataset-zoo)
- [More About Qlib](#more-about-qlib)
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
@@ -39,19 +41,17 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
# Framework of Qlib
-

+
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
-| Name | Description |
-| ------ | ----- |
-| `Data layer` | `DataServer` focuses on providing high-performance infrastructure for users to manage and retrieve raw data. `DataEnhancement` will preprocess the data and provide the best dataset to be fed into the models. |
-| `Interday Model` | `Interday model` focuses on producing prediction scores (aka. _alpha_). Models are trained by `Model Creator` and managed by `Model Manager`. Users could choose one or multiple models for prediction. Multiple models could be combined with `Ensemble` module. |
-| `Interday Strategy` | `Portfolio Generator` will take prediction scores as input and output the orders based on the current position to achieve the target portfolio. |
-| `Intraday Trading` | `Order Executor` is responsible for executing orders output by `Interday Strategy` and returning the executed results. |
-| `Analysis` | Users could get a detailed analysis report of forecasting signals and portfolios in this part. |
+| Name | Description |
+| ------ | ----- |
+| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
+| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
+| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
* The modules with hand-drawn style are under development and will be released in the future.
* The modules with dashed borders are highly user-customizable and extendible.
@@ -139,17 +139,20 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
```bash
- risk
- excess_return_without_cost mean 0.000675
- std 0.005456
- annualized_return 0.170077
- information_ratio 1.963824
- max_drawdown -0.063646
- excess_return_with_cost mean 0.000479
- std 0.005453
- annualized_return 0.120776
- information_ratio 1.395116
- max_drawdown -0.071216
+ 'The following are analysis results of the excess return without cost.'
+ risk
+ mean 0.000708
+ std 0.005626
+ annualized_return 0.178316
+ information_ratio 1.996555
+ max_drawdown -0.081806
+ 'The following are analysis results of the excess return with cost.'
+ risk
+ mean 0.000512
+ std 0.005626
+ annualized_return 0.128982
+ information_ratio 1.444287
+ max_drawdown -0.091078
@@ -159,19 +162,19 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports
- Forecasting signal (model prediction) analysis
- Cumulative Return of groups
- 
+ 
- Return distribution
- 
+ 
- Information Coefficient (IC)
- 
- 
- 
+ 
+ 
+ 
- Auto Correlation of forecasting signal (model prediction)
- 
+ 
- Portfolio analysis
- Backtest return
- 
+ 
+- [HATs based on pytorch](qlib/contrib/model/pytorch_hats.py)
+- [TFT based on tensorflow](examples/benchmarks/TFT/tft.py)
Your PR of new Quant models is highly welcomed.
diff --git a/docs/_static/img/analysis/analysis_model_IC.png b/docs/_static/img/analysis/analysis_model_IC.png
index 0064fb890..26b4b4bfa 100644
Binary files a/docs/_static/img/analysis/analysis_model_IC.png and b/docs/_static/img/analysis/analysis_model_IC.png differ
diff --git a/docs/_static/img/analysis/analysis_model_NDQ.png b/docs/_static/img/analysis/analysis_model_NDQ.png
index c1824368b..5197c4b03 100644
Binary files a/docs/_static/img/analysis/analysis_model_NDQ.png and b/docs/_static/img/analysis/analysis_model_NDQ.png differ
diff --git a/docs/_static/img/analysis/analysis_model_auto_correlation.png b/docs/_static/img/analysis/analysis_model_auto_correlation.png
index 3f213a79b..ab9e30165 100644
Binary files a/docs/_static/img/analysis/analysis_model_auto_correlation.png and b/docs/_static/img/analysis/analysis_model_auto_correlation.png differ
diff --git a/docs/_static/img/analysis/analysis_model_cumulative_return.png b/docs/_static/img/analysis/analysis_model_cumulative_return.png
index bcccf138a..c305a42b4 100644
Binary files a/docs/_static/img/analysis/analysis_model_cumulative_return.png and b/docs/_static/img/analysis/analysis_model_cumulative_return.png differ
diff --git a/docs/_static/img/analysis/analysis_model_long_short.png b/docs/_static/img/analysis/analysis_model_long_short.png
index 2fcb08c4e..5efed2d6c 100644
Binary files a/docs/_static/img/analysis/analysis_model_long_short.png and b/docs/_static/img/analysis/analysis_model_long_short.png differ
diff --git a/docs/_static/img/analysis/analysis_model_monthly_IC.png b/docs/_static/img/analysis/analysis_model_monthly_IC.png
index 0056c6c9c..8443f3860 100644
Binary files a/docs/_static/img/analysis/analysis_model_monthly_IC.png and b/docs/_static/img/analysis/analysis_model_monthly_IC.png differ
diff --git a/docs/_static/img/analysis/report.png b/docs/_static/img/analysis/report.png
index dfd227f5a..2901da603 100644
Binary files a/docs/_static/img/analysis/report.png and b/docs/_static/img/analysis/report.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_annualized_return.png b/docs/_static/img/analysis/risk_analysis_annualized_return.png
index 1979ca19b..18e7a90aa 100644
Binary files a/docs/_static/img/analysis/risk_analysis_annualized_return.png and b/docs/_static/img/analysis/risk_analysis_annualized_return.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_bar.png b/docs/_static/img/analysis/risk_analysis_bar.png
index 1cce1f340..c90650a6d 100644
Binary files a/docs/_static/img/analysis/risk_analysis_bar.png and b/docs/_static/img/analysis/risk_analysis_bar.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_information_ratio.png b/docs/_static/img/analysis/risk_analysis_information_ratio.png
index edc64b17d..7028eaf02 100644
Binary files a/docs/_static/img/analysis/risk_analysis_information_ratio.png and b/docs/_static/img/analysis/risk_analysis_information_ratio.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_max_drawdown.png b/docs/_static/img/analysis/risk_analysis_max_drawdown.png
index a68810222..b7f1ae130 100644
Binary files a/docs/_static/img/analysis/risk_analysis_max_drawdown.png and b/docs/_static/img/analysis/risk_analysis_max_drawdown.png differ
diff --git a/docs/_static/img/analysis/risk_analysis_std.png b/docs/_static/img/analysis/risk_analysis_std.png
index 73d782e20..6f38def26 100644
Binary files a/docs/_static/img/analysis/risk_analysis_std.png and b/docs/_static/img/analysis/risk_analysis_std.png differ
diff --git a/docs/_static/img/analysis/score_ic.png b/docs/_static/img/analysis/score_ic.png
index 6e1d37d2a..a5739a9ba 100644
Binary files a/docs/_static/img/analysis/score_ic.png and b/docs/_static/img/analysis/score_ic.png differ
diff --git a/docs/_static/img/framework.png b/docs/_static/img/framework.png
index 673f10e03..d8242f7c1 100644
Binary files a/docs/_static/img/framework.png and b/docs/_static/img/framework.png differ
diff --git a/docs/component/data.rst b/docs/component/data.rst
index 9ef71a6cb..aa01fe226 100644
--- a/docs/component/data.rst
+++ b/docs/component/data.rst
@@ -1,7 +1,7 @@
.. _data:
================================
-Data Layer: Data Framework&Usage
+Data Layer: Data Framework & Usage
================================
Introduction
@@ -15,7 +15,9 @@ The introduction of ``Data Layer`` includes the following parts.
- Data Preparation
- Data API
+- Data Loader
- Data Handler
+- Dataset
- Cache
- Data and Cache File Structure
@@ -31,13 +33,19 @@ Such data will be stored with filename suffix `.bin` (We'll call them `.bin` fil
Qlib Format Dataset
--------------------
-``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the dataset as follows.
+``Qlib`` has provided an off-the-shelf dataset in `.bin` format, users could use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
-After running the above command, users can find china-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory.
+In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, which can be downloaded with the following command:
+
+.. code-block:: bash
+
+ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
+
+After running the above command, users can find china-stock and us-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
@@ -49,12 +57,45 @@ Converting CSV Format into Qlib Format
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert data in CSV format into `.bin` files (Qlib format).
-Users can download the china-stock data in CSV format as follows for reference to the CSV format.
+Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
.. code-block:: bash
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
+Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
+
+- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
+
+ - Name the CSV file after a stock: `SH600000.csv`, `AAPL.csv` (not case sensitive).
+
+ - CSV file includes a column of the stock name. User **must** specify the column name when dumping the data. Here is an example:
+
+ .. code-block:: bash
+
+ python scripts/dump_bin.py dump_all ... --symbol_field_name symbol
+
+ where the data are in the following format:
+
+ .. code-block::
+
+ symbol,close
+ SH600000,120
+
+- CSV file **must** includes a column for the date, and when dumping the data, user must specify the date column name. Here is an example:
+
+ .. code-block:: bash
+
+ python scripts/dump_bin.py dump_all ... --date_field_name date
+
+ where the data are in the following format:
+
+ .. code-block::
+
+ symbol,date,close,open,volume
+ SH600000,2020-11-01,120,121,12300000
+ SH600000,2020-11-02,123,120,12300000
+
Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv_data/my_data``, they can run the following command to start the conversion.
@@ -62,6 +103,12 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv
python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
+For other supported parameters when dumping the data into `.bin` file, users can refer to the information by running the following commands:
+
+.. code-block:: bash
+
+ python dump_bin.py dump_all --help
+
After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`.
.. note::
@@ -97,9 +144,8 @@ China-Stock Mode & US-Stock Mode
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=REG_CN)
-- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` does not provide a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
- - Prepare data in CSV format
- - Convert data from CSV format to Qlib format, please refer to section `Converting CSV Format into Qlib Format <#converting-csv-format-into-qlib-format>`_.
+- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
+ - Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
- Initialize ``Qlib`` in US-stock mode
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
@@ -141,68 +187,97 @@ Filter
Expression dynamic instrument filter. Filter the instruments based on a certain expression. An expression rule indicating a certain feature field is required.
- `basic features filter`: rule_expression = '$close/$open>5'
- - `cross-sectional features filter` : rule_expression = '$rank($close)<10'
+ - `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
-
Reference
-------------
To know more about ``Data API``, please refer to `Data API <../reference/api.html#data>`_.
+
+Data Loader
+=================
+
+``Data Loader`` in ``Qlib`` is designed to load raw data from the original data source. It will be loaded and used in the ``Data Handler`` module.
+
+QlibDataLoader
+---------------
+
+The ``QlibDataLoader`` class in ``Qlib`` is such an interface that allows users to load raw data from the data source.
+
+Interface
+------------
+
+Here are some interfaces of the ``QlibDataLoader`` class:
+
+.. autoclass:: qlib.data.dataset.loader.QlibDataLoader
+ :members: load, load_group_df
+
+API
+-----------
+
+To know more about ``Data Loader``, please refer to `Data Loader API <../reference/api.html#module-qlib.data.dataset.loader>`_.
+
+
Data Handler
=================
-Users can use ``Data Handler`` in an automatic workflow by ``Estimator``, refer to `Estimator: Workflow Management `_ for more details.
+The ``Data Handler`` module in ``Qlib`` is designed to handler those common data processing methods which will be used by most of the models.
-Also, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data(standardization, remove NaN, etc.) and build datasets. It is a subclass of ``qlib.data.dataset.handler.DataHandlerLP``, which provides some interfaces as follows.
+Users can use ``Data Handler`` in an automatic workflow by ``qrun``, refer to `Workflow: Workflow Management `_ for more details.
-Base Class & Interface
+DataHandlerLP
+--------------
+
+In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
+
+In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some leanable ``Processors`` which can learn the parameters of data processing. When new data comes in, these `trained` ``Processors`` can then infer on the new data and thus processing real-time data in an efficient way. More information about ``Processors`` will be listed in the next subsection.
+
+
+Interface
----------------------
-Qlib provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_, which provides the following interfaces:
+Here are some important interfaces that ``DataHandlerLP`` provides:
-- `load_feature`
- Implement the interface to load the data features.
-
-- `load_label`
- Implement the interface to load the data labels and calculate the users' labels.
-
-- `setup_processed_data`
- Implement the interface for data preprocessing, such as preparing feature columns, discarding blank lines, and so on.
-
-Qlib also provides two functions to help users init the data handler, users can override them for users' needs.
-
-- `_init_raw_data`
- Users can init the raw df, feature names, and label names of data handler in this function.
- If the index of feature df and label df are not the same, users need to override this method to merge them (e.g. inner, left, right merge).
+.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
+ :members: __init__, fetch, get_cols
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
+
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
-Usage
---------------
+Processor
+----------
-``Data Handler`` can be used as a single module, which provides the following mehtods:
+The ``Processor`` module in ``Qlib`` is designed to be learnable and it is responsible for handling data processing such as `normalization` and `drop none/nan features/labels`.
-- `get_split_data`
- - According to the start and end dates, return features and labels of the pandas DataFrame type used for the 'Model'
-
-- `get_rolling_data`
- - According to the start and end dates, and `rolling_period`, an iterator is returned, which can be used to traverse the features and labels used for rolling.
+``Qlib`` provides the following ``Processors``:
+- ``DropnaProcessor``: `processor` that drops N/A features.
+- ``DropnaLabel``: `processor` that drops N/A labels.
+- ``TanhProcess``: `processor` that uses `tanh` to process noise data.
+- ``ProcessInf``: `processor` that handles infinity values, it will be replaces by the mean of the column.
+- ``Fillna``: `processor` that handles N/A values, which will fill the N/A value by 0 or other given number.
+- ``MinMaxNorm``: `processor` that applies min-max normalization.
+- ``ZscoreNorm``: `processor` that applies z-score normalization.
+- ``RobustZScoreNorm``: `processor` that applies robust z-score normalization.
+- ``CSZScoreNorm``: `processor` that applies cross sectional z-score normalization.
+- ``CSRankNorm``: `processor` that applies cross sectional rank normalization.
+Users can also create their own `processor` by inheriting the base class of ``Processor``. Please refer to the implementation of all the processors for more information (`Processor Link `_).
+To know more about ``Processor``, please refer to `Processor API <../reference/api.html#module-qlib.data.dataset.processor>`_.
Example
--------------
-``Data Handler`` can be run with ``estimator`` by modifying the configuration file, and can also be used as a single module.
+``Data Handler`` can be run with ``qrun`` by modifying the configuration file, and can also be used as a single module.
-Know more about how to run ``Data Handler`` with ``Estimator``, please refer to `Estimator: Workflow Management `_
+Know more about how to run ``Data Handler`` with ``qrun``, please refer to `Workflow: Workflow Management `_
Qlib provides implemented data handler `Alpha158`. The following example shows how to run `Alpha158` as a single module.
@@ -211,45 +286,55 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
.. code-block:: Python
+ import qlib
from qlib.contrib.data.handler import Alpha158
- from qlib.contrib.model.gbdt import LGBModel
- DATA_HANDLER_CONFIG = {
- "dropna_label": True,
- "start_date": "2007-01-01",
- "end_date": "2020-08-01",
- "market": "csi300",
+ data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": "csi300",
}
- TRAINER_CONFIG = {
- "train_start_date": "2007-01-01",
- "train_end_date": "2014-12-31",
- "validate_start_date": "2015-01-01",
- "validate_end_date": "2016-12-31",
- "test_start_date": "2017-01-01",
- "test_end_date": "2020-08-01",
- }
+ if __name__ == "__main__":
+ qlib.init()
+ h = Alpha158(**data_handler_config)
- exampleDataHandler = Alpha158(**DATA_HANDLER_CONFIG)
+ # get all the columns of the data
+ print(h.get_cols())
- # example of 'get_split_data'
- x_train, y_train, x_validate, y_validate, x_test, y_test = exampleDataHandler.get_split_data(**TRAINER_CONFIG)
+ # fetch all the labels
+ print(h.fetch(col_set="label"))
- # example of 'get_rolling_data'
-
- for (x_train, y_train, x_validate, y_validate, x_test, y_test) in exampleDataHandler.get_rolling_data(**TRAINER_CONFIG):
- print(x_train, y_train, x_validate, y_validate, x_test, y_test)
-
-
-.. note:: (x_train, y_train, x_validate, y_validate, x_test, y_test) can be used as arguments for the `fit`, `predic``, and `score` methods of the ``Interday Model`` , please refer to `Model `_.
-
-Also, the above example has been given in ``examples.estimator.train_backtest_analyze.ipynb``.
+ # fetch all the features
+ print(h.fetch(col_set="feature"))
API
---------
To know more about ``Data Handler``, please refer to `Data Handler API <../reference/api.html#module-qlib.data.dataset.handler>`_.
+
+Dataset
+=================
+
+The ``Dataset`` module in ``Qlib`` aims to prepare data for model training and inferencing.
+
+The motivation of this module is that we want to maximize the flexibility of of different models to handle data that are suitable for themselves. This module gives the model the rights to process their data in an unique way. For instance, models such as ``GBDT`` may work well on data that contains `nan` or `None` value, while neural networks such as ``DNN`` will break down on such data.
+
+The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most important interface of the class:
+
+.. autoclass:: qlib.data.dataset.__init__.DatasetH
+ :members:
+
+API
+---------
+
+To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
+
+
+
Cache
==========
diff --git a/docs/component/model.rst b/docs/component/model.rst
index 6a6b02f86..b4e341df8 100644
--- a/docs/component/model.rst
+++ b/docs/component/model.rst
@@ -7,7 +7,7 @@ Interday Model: Model Training & Prediction
Introduction
===================
-``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``Estimator``, please refer to `Estimator: Workflow Management `_.
+``Interday Model`` is designed to make the `prediction score` about stocks. Users can use the ``Interday Model`` in an automatic workflow by ``qrun``, please refer to `Workflow: Workflow Management `_.
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Interday Model`` can be used as an independent module also.
@@ -20,151 +20,125 @@ The base class provides the following interfaces:
- `__init__(**kwargs)`
- Initialization.
- - If users use ``Estimator`` to start an `experiment`, the parameter of `__init__` method shoule be consistent with the hyperparameters in the configuration file.
-- `fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs)`
+- `fit(self, dataset, **kwargs)`
- Train model.
- Parameter:
- - `x_train`, pd.DataFrame type, train feature
- The following example explains the value of `x_train`:
+ - `dataset`, ``Qlib``'s ``DatasetH`` type. For more information about ``DatasetH``, users can refer to the related document: `Qlib Dataset <../component/data.html#dataset>`_.
+ The `dataset` is passed into the `model`'s method because there are some unique data preprocessing procedures for each, we want to give each model maximum flexibility to handle the data that is suitable for their own.
+ The following code example shows how to retrieve `x_train`, `y_train` and `w_train` from the `dataset`:
- .. code-block:: YAML
-
- KMID KLEN KMID2 KUP KUP2
- instrument datetime
- SH600004 2012-01-04 0.000000 0.017685 0.000000 0.012862 0.727275
- 2012-01-05 -0.006473 0.025890 -0.250001 0.012945 0.499998
- 2012-01-06 0.008117 0.019481 0.416666 0.008117 0.416666
- 2012-01-09 0.016051 0.025682 0.624998 0.006421 0.250001
- 2012-01-10 0.017323 0.026772 0.647057 0.003150 0.117648
- ... ... ... ... ... ...
- SZ300273 2014-12-25 -0.005295 0.038697 -0.136843 0.016293 0.421052
- 2014-12-26 -0.022486 0.041701 -0.539215 0.002453 0.058824
- 2014-12-29 -0.031526 0.039092 -0.806451 0.000000 0.000000
- 2014-12-30 -0.010000 0.032174 -0.310811 0.013913 0.432433
- 2014-12-31 0.010917 0.020087 0.543479 0.001310 0.065216
+ .. code-block:: Python
-
- `x_train` is a pandas DataFrame, whose index is MultiIndex . Each column of `x_train` corresponds to a feature, and the column name is the feature name.
-
- .. note::
-
- The number and names of the columns are determined by the data handler, please refer to `Data Handler `_ and `Estimator Data Section `_.
-
- - `y_train`, pd.DataFrame type, train label
- The following example explains the value of `y_train`:
+ # get features and labels
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
- .. code-block:: YAML
-
- LABEL
- instrument datetime
- SH600004 2012-01-04 -0.798456
- 2012-01-05 -1.366716
- 2012-01-06 -0.491026
- 2012-01-09 0.296900
- 2012-01-10 0.501426
- ... ...
- SZ300273 2014-12-25 -0.465540
- 2014-12-26 0.233864
- 2014-12-29 0.471368
- 2014-12-30 0.411914
- 2014-12-31 1.342723
-
- `y_train` is a pandas DataFrame, whose index is MultiIndex . The `LABEL` column represents the value of train label.
-
- .. note::
-
- The number and names of the columns are determined by the ``Data Handler``, please refer to `Data Handler `_.
-
- - `x_valid`, pd.DataFrame type, validation feature
- The format of `x_valid` is same as `x_train`
-
-
- - `y_valid`, pd.DataFrame type, validation label
- The format of `y_valid` is same as `y_train`
-
- - `w_train`(Optional args, default is None), pd.DataFrame type, train weight
- `w_train` is a pandas DataFrame, whose shape and index is same as `x_train`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
-
- - `w_train`(Optional args, default is None), pd.DataFrame type, validation weight
- `w_train` is a pandas DataFrame, whose shape and index is the same as `x_valid`. The float value in `w_train` represents the weight of the feature at the same position in `x_train`.
-
-- `predict(self, x_test, **kwargs)`
- - Predict test data 'x_test'
- - Parameter:
- - `x_test`, pd.DataFrame type, test features
- The form of `x_test` is same as `x_train` in 'fit' method.
- - Return:
- - `label`, np.ndarray type, test label
- The label of `x_test` that predicted by model.
-
-- `score(self, x_test, y_test, w_test=None, **kwargs)`
- - Evaluate model with test feature/label
- - Parameter:
- - `x_test`, pd.DataFrame type, test feature
- The format of `x_test` is same as `x_train` in `fit` method.
+ # get weights
+ try:
+ wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
+ w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
+ except KeyError as e:
+ w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
+ w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
- - `x_test`, pd.DataFrame type, test label
- The format of `y_test` is same as `y_train` in `fit` method.
+- `predict(self, dataset, **kwargs)`
+ - Predict test data.
+ - Parameter:
+ - `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
+ - Returns:
+ - Predic results with type: `pandas.Series`.
- - `w_test`, pd.DataFrame type, test weight
- The format of `w_test` is same as `w_train` in `fit` method.
- - Return: float type, evaluation score
+- `finetune(self, dataset, **kwargs)`
+ - Finetune the model.
+ - Parameter:
+ - `dataset`, ``Qlib``'s ``DatasetH`` type. The usage is similar to the example above.
-For other interfaces such as `save`, `load`, `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
+
+For other interfaces such as `finetune`, please refer to `Model API <../reference/api.html#module-qlib.model.base>`_.
Example
==================
-``Qlib`` provides ``LightGBM`` and ``DNN`` models as the baseline, the following steps show how to run`` LightGBM`` as an independent module.
+``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``DNN``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. The following steps show how to run`` LightGBM`` as an independent module.
- Initialize ``Qlib`` with `qlib.init` first, please refer to `Initialization <../start/initialization.html>`_.
- Run the following code to get the `prediction score` `pred_score`
.. code-block:: Python
- from qlib.contrib.data.handler import Alpha158
from qlib.contrib.model.gbdt import LGBModel
+ from qlib.contrib.data.handler import Alpha158
+ from qlib.utils import init_instance_by_config, flatten_dict
+ from qlib.workflow import R
+ from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
- DATA_HANDLER_CONFIG = {
- "dropna_label": True,
- "start_date": "2007-01-01",
- "end_date": "2020-08-01",
- "market": MARKET,
+ market = "csi300"
+ benchmark = "SH000300"
+
+ data_handler_config = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": market,
}
- TRAINER_CONFIG = {
- "train_start_date": "2007-01-01",
- "train_end_date": "2014-12-31",
- "validate_start_date": "2015-01-01",
- "validate_end_date": "2016-12-31",
- "test_start_date": "2017-01-01",
- "test_end_date": "2020-08-01",
+ task = {
+ "model": {
+ "class": "LGBModel",
+ "module_path": "qlib.contrib.model.gbdt",
+ "kwargs": {
+ "loss": "mse",
+ "colsample_bytree": 0.8879,
+ "learning_rate": 0.0421,
+ "subsample": 0.8789,
+ "lambda_l1": 205.6999,
+ "lambda_l2": 580.9768,
+ "max_depth": 8,
+ "num_leaves": 210,
+ "num_threads": 20,
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "Alpha158",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": data_handler_config,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ },
}
+
+ # model initiaiton
+ model = init_instance_by_config(task["model"])
+ dataset = init_instance_by_config(task["dataset"])
- x_train, y_train, x_validate, y_validate, x_test, y_test = Alpha158(
- **DATA_HANDLER_CONFIG
- ).get_split_data(**TRAINER_CONFIG)
+ # start exp
+ with R.start(experiment_name="workflow"):
+ # train
+ R.log_params(**flatten_dict(task))
+ model.fit(dataset)
+ # prediction
+ recorder = R.get_recorder()
+ sr = SignalRecord(model, dataset, recorder)
+ sr.generate()
- MODEL_CONFIG = {
- "loss": "mse",
- "colsample_bytree": 0.8879,
- "learning_rate": 0.0421,
- "subsample": 0.8789,
- "lambda_l1": 205.6999,
- "lambda_l2": 580.9768,
- "max_depth": 8,
- "num_leaves": 210,
- "num_threads": 20,
- }
- # use default model
- model = LGBModel(**MODEL_CONFIG)
- model.fit(x_train, y_train, x_validate, y_validate)
- _pred = model.predict(x_test)
- pred_score = pd.DataFrame(index=_pred.index)
- pred_score["score"] = _pred.iloc(axis=1)[0]
-
- .. note:: `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler `_.
+ .. note::
+
+ `Alpha158` is the data handler provided by ``Qlib``, please refer to `Data Handler `_.
+ `SignalRecord` is the `Record Template` in ``Qlib``, please refer to `Workflow `_.
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
diff --git a/docs/component/recorder.rst b/docs/component/recorder.rst
index efd67e859..4304dcce5 100644
--- a/docs/component/recorder.rst
+++ b/docs/component/recorder.rst
@@ -50,312 +50,17 @@ Qlib Recorder
Here are the available interfaces of ``QlibRecorder``:
-- `__init__(exp_manager)`
- - Initialization.
- - It takes in an input: `exp_manager`, which is an `ExperimentManager` instance. The instance will be created during ``qlib.init``.
-
-- `start(experiment_name=None, recorder_name=None)`
- - High level API to start an experiment. This method can only be called within a Python's '`with`' statement.
- - Parameters:
- - `experiment_name` : str
- name of the experiment one wants to start.
- - `recorder_name` : str
- name of the recorder under the experiment one wants to start.
- - Use case:
-
- .. code-block:: Python
-
- with R.start('test', 'recorder_1'):
- model.fit(dataset)
- R.log...
- ... # further operations
-
-- `start_exp(experiment_name=None, recorder_name=None, uri=None)`
- - Lower level method for starting an experiment. When use this method, one should end the experiment manually and the status of the recorder may not be handled properly.
- - Parameters:
- - `experiment_name` : str
- the name of the experiment to be started
- - `recorder_name` : str
- name of the recorder under the experiment one wants to start.
- - `uri` : str
- the tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
- The default uri are set in the qlib.config.
- - Returns:
- - an experiment instance being started.
- - Use case:
-
- .. code-block:: Python
-
- R.start_exp(experiment_name='test', recorder_name='recorder_1')
- ... # further operations
- R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
-
-- `end_exp(recorder_status=Recorder.STATUS_FI)`
- - Method for ending an experiment manually. It will end the current active experiment, as well as its active recorder with the specified `status` type.
- - Parameters:
- - `status` : str
- The status of a recorder, which can be '`SCHEDULED`', '`RUNNING`', '`FINISHED`', '`FAILED`'.
- - Use case:
-
- .. code-block:: Python
-
- R.start_exp(experiment_name='test')
- ... # further operations
- R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
-
-- `search_records(experiment_ids, **kwargs)`
- - Get a pandas DataFrame of all the records that have been stored with the given search criteria. This method is highly correlated with MLFlow's ``search_runs`` method (`link `_).
- - Parameters:
- - `experiment_ids` : list
- list of experiment IDs.
- - `filter_string` : str
- filter query string, defaults to searching all runs.
- - `run_view_type` : int
- one of enum values ACTIVE_ONLY (1), DELETED_ONLY (2), or ALL (3).
- - `max_results` : int
- the maximum number of runs to put in the dataframe.
- - `order_by` : list
- list of columns to order by (e.g., “metrics.rmse”).
- - Returns:
- - A pandas.DataFrame of records, where each metric, parameter, and tag are expanded into their own columns named metrics.*, params.*, and tags.* respectively. For records that don't have a particular metric, parameter, or tag, their value will be (NumPy) Nan, None, or None respectively.
- - Use case:
-
- .. code-block:: Python
-
- R.log_metrics(m=2.50, step=0)
- records = R.search_runs([experiment_id], order_by=["metrics.m DESC"])
-
-- `list_experiments()`
- - Method for listing all the existing experiments (except for those being deleted.)
- - Returns:
- - A dictionary (name -> experiment) of experiments information that being stored.
- - Use case:
-
- .. code-block:: Python
-
- exps = R.list_experiments()
-
-- `list_recorders(experiment_id=None, experiment_name=None)`
- - Method for listing all the recorders of experiment with given id or name. If user doesn't provide the id or name of the experiment, this method will try to retrieve the default experiment and list all the recorders of the default experiment. If the default experiment doesn't exist, the method will first create the default experiment, and then create a new recorder under it.
- - Parameters:
- - `experiment_id` : str
- id of the experiment.
- - `experiment_name` : str
- name of the experiment.
- - Returns:
- - A dictionary (id -> recorder) of recorder information that being stored.
- - Use case:
-
- .. code-block:: Python
-
- recorders = R.list_recorders(experiment_name='test')
-
-- `get_exp(experiment_id=None, experiment_name=None, create: bool = True)`
- - Method for retrieving an experiment with given id or name. Once the '`create`' argument is set to True, if no valid experiment is found, this method will create one for the user. Otherwise, it will only retrieve a specific experiment or raise an Error.
-
- - If '`create`' is True:
- - If ``R``'s running:
- - no id or name specified, return the active experiment.
- - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be running.
- - If ``R``'s not running:
- - no id or name specified, create a default experiment, and the experiment is set to be running.
- - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be running.
- - Else If '`create`' is False:
- - If ``R``'s running:
- - no id or name specified, return the active experiment.
- - if id or name is specified, return the specified experiment. If no such exp found, raise Error.
- - If ``R``'s not running:
- - no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
- - if id or name is specified, return the specified experiment. If no such exp found, raise Error.
- - Parameters:
- - `experiment_id` : str
- id of the experiment.
- - `experiment_name` : str
- name of the experiment.
- - `create` : boolean
- an argument determines whether the method will automatically create a new experiment according to user's specification if the experiment hasn't been created before.
- - Returns:
- - An experiment instance with given id or name.
- - Use case:
-
- .. code-block:: Python
-
- # Case 1
- with R.start('test'):
- exp = R.get_exp()
- recorders = exp.list_recorders()
-
- # Case 2
- with R.start('test'):
- exp = R.get_exp('test1')
-
- # Case 3
- exp = R.get_exp() -> a default experiment.
-
- # Case 4
- exp = R.get_exp(experiment_name='test')
-
- # Case 5
- exp = R.get_exp(create=False) -> the default experiment if exists.
-
-- `delete_exp(experiment_id=None, experiment_name=None)`
- - Method for deleting the experiment with given id or name. At least one of id or name must be given, otherwise, error will occur.
- - Parameters:
- - `experiment_id` : str
- id of the experiment.
- - `experiment_name` : str
- name of the experiment.
- - Use case:
-
- .. code-block:: Python
-
- R.delete_exp(experiment_name='test')
-
-- `get_uri()`
- - Method for retrieving the uri of current experiment manager.
- - Returns:
- - The uri of current experiment manager.
- - Use case:
-
- .. code-block:: Python
-
- uri = R.get_uri()
-
-- `get_recorder(recorder_id=None, recorder_name=None, experiment_name=None)`
- - Method for retrieving a recorder. The recorder can be used for further process such as ``save_objects``, ``load_object``, ``log_params``, ``log_metrics``, etc.
-
- - If ``R``'s running:
- - no id or name specified, return the active recorder.
- - if id or name is specified, return the specified recorder.
- - If ``R``'s not running:
- - no id or name specified, raise Error.
- - if id or name is specified, and the corresponding experiment_name must be given, return the specified recorder. Otherwise, raise Error.
- - Parameters:
- - `recorder_id` : str
- id of the recorder.
- - `recorder_name` : str
- name of the recorder.
- - `experiment_name` : str
- name of the experiment.
- - Returns:
- - A recorder instance.
- - Use case:
-
- .. code-block:: Python
-
- # Case 1
- with R.start('test'):
- recorder = R.get_recorder()
-
- # Case 2
- with R.start('test'):
- recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
-
- # Case 3
- recorder = R.get_recorder() -> Error
-
- # Case 4
- recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d') -> Error
-
- # Case 5
- recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')
-
-- `delete_recorder(recorder_id=None, recorder_name=None)`
- - Method for deleting the recorders with given id or name. At least one of id or name must be given, otherwise, error will occur.
- - Parameters:
- - `recorder_id` : str
- id of the experiment.
- - `recorder_name` : str
- name of the experiment.
- - Use case:
-
- .. code-block:: Python
-
- R.delete_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
-
-- `save_objects(local_path=None, artifact_path=None, **kwargs)`
- - Method for saving objects as artifacts in the experiment to the uri. It supports either saving from a local file/directory, or directly saving objects. User can use valid python's keywords arguments to specify the object to be saved as well as its name (name: value).
-
- - If R's running: it will save the objects through the running recorder.
- - If R's not running: the system will create a default experiment, and a new recorder and save objects under it.
-
- .. note::
-
- If one wants to save objects with a specific recorder. It is recommended to first get the specific recorder through `get_recorder` API and use the recorder the save objects. The supported arguments are the same as this method.
-
- - Parameters:
- - `local_path` : str
- if provided, them save the file or directory to the artifact URI.
- - `artifact_path` : str
- the relative path for the artifact to be stored in the URI.
- - Use case:
-
- .. code-block:: Python
-
- # Case 1
- with R.start('test'):
- pred = model.predict(dataset)
- R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
-
- # Case 2
- with R.start('test'):
- R.save_objects(local_path='results/pred.pkl')
-
-- `log_params(**kwargs)`
- - Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
-
- - If R's running: it will log parameters through the running recorder.
- - If R's not running: the system will create a default experiment as well as a new recorder, and log parameters under it.
- - Parameters:
- - `keyword argument`:
- name1=value1, name2=value2, ...
- - Use case:
-
- .. code-block:: Python
-
- # Case 1
- with R.start('test'):
- R.log_params(learning_rate=0.01)
-
- # Case 2
- R.log_params(learning_rate=0.01)
-
-- `log_metrics(step=None, **kwargs)`
- - Method for logging metrics during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
-
- - If R's running: it will log metrics through the running recorder.
- - If R's not running: the system will create a default experiment as well as a new recorder, and log metrics under it.
- - Parameters:
- - `step`: int
- a single integer step at which to log the specified Metrics. If unspecified, each metric is logged at step zero.
- - `keyword argument`:
- name1=value1, name2=value2, ...
-
-- `set_tags(**kwargs)`
- - Method for setting tags for a recorder. In addition to using ``R``, one can also set the tag to a specific recorder after getting it with `get_recorder` API.
-
- - If R's running: it will set tags through the running recorder.
- - If R's not running: the system will create a default experiment as well as a new recorder, and set the tags under it.
- - Parameters:
- - `keyword argument`:
- name1=value1, name2=value2, ...
- - Use case:
-
- .. code-block:: Python
-
- # Case 1
- with R.start('test'):
- R.set_tags(release_version="2.2.0")
-
- # Case 2
- R.set_tags(release_version="2.2.0")
-
+.. autoclass:: qlib.workflow.__init__.QlibRecorder
+ :members:
Experiment Manager
===================
The ``ExpManager`` module in ``Qlib`` is responsible for managing different experiments. Most of the APIs of ``ExpManager`` are similar to ``QlibRecorder``, and the most important API will be the ``get_exp`` method. User can directly refer to the documents above for some detailed information about how to use the ``get_exp`` method.
+.. autoclass:: qlib.workflow.expm.ExpManager
+ :members: get_exp, list_experiments
+
For other interfaces such as `create_exp`, `delete_exp`, please refer to `Experiment Manager API <../reference/api.html#experiment-manager>`_.
Experiment
@@ -363,6 +68,9 @@ Experiment
The ``Experiment`` class is solely responsible for a single experiment, and it will handle any operations that are related to an experiment. Basic methods such as `start`, `end` an experiment are included. Besides, methods related to `recorders` are also available: such methods include `get_recorder` and `list_recorders`.
+.. autoclass:: qlib.workflow.exp.Experiment
+ :members: get_recorder, list_recorders
+
For other interfaces such as `search_records`, `delete_recorder`, please refer to `Experiment API <../reference/api.html#experiment>`_.
Recorder
@@ -372,28 +80,8 @@ The ``Recorder`` class is responsible for a single recorder. It will handle some
Here are some important APIs that are not included in the ``QlibRecorder``:
-- `list_artifacts(artifact_path: str = None)`
- - List all the artifacts of a recorder.
- - Parameters:
- - `artifact_path` : str
- the relative path for the artifact to be stored in the URI.
- - Returns:
- - A list of artifacts information (name, path, etc.) that being stored.
-
-- `list_metrics()`
- - List all the metrics of a recorder.
- - Returns:
- - A dictionary of metrics that being stored.
-
-- `list_params()`
- - List all the params of a recorder.
- - Returns:
- - A dictionary of params that being stored.
-
-- `list_tags()`
- - List all the tags of a recorder.
- - Returns:
- - A dictionary of tags that being stored.
+.. autoclass:: qlib.workflow.recorder.Recorder
+ :members: list_artifacts, list_metrics, list_params, list_tags
For other interfaces such as `save_objects`, `load_object`, please refer to `Recorder API <../reference/api.html#recorder>`_.
@@ -402,8 +90,8 @@ Record Template
The ``RecordTemp`` class is a class that enables generate experiment results such as IC and backtest in a certain format. We have provided three different `Record Template` class:
-- ``SignalRecord``: This class generates the `preidction` of the model.
-- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR`.
+- ``SignalRecord``: This class generates the `preidction` results of the model.
+- ``SigAnaRecord``: This class generates the `IC`, `ICIR`, `Rank IC` and `Rank ICIR` of the model.
- ``PortAnaRecord``: This class generates the results of `backtest`. The detailed information about `backtest` as well as the available `strategy`, users can refer to `Strategy <../component/strategy.html>`_ and `Backtest <../component/backtest.html>`_.
-For more information, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
\ No newline at end of file
+For more information about the APIs, please refer to `Record Template API <../reference/api.html#module-qlib.workflow.record_temp>`_.
\ No newline at end of file
diff --git a/docs/component/report.rst b/docs/component/report.rst
index 8ea3d7abe..7d8053c78 100644
--- a/docs/component/report.rst
+++ b/docs/component/report.rst
@@ -1,13 +1,13 @@
.. _report:
==========================================
-Aanalysis: Evaluation & Results Analysis
+Analysis: Evaluation & Results Analysis
==========================================
Introduction
===================
-``Aanalysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
+``Analysis`` is designed to show the graphical reports of ``Intraday Trading`` , which helps users to evaluate and analyse investment portfolios visually. The following are some graphics to view:
- analysis_position
- report_graph
diff --git a/docs/conf.py b/docs/conf.py
index b91efb9a9..5359d08ed 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -124,7 +124,7 @@ html_theme_options = {
"logo_only": True,
"collapse_navigation": False,
"display_version": False,
- "navigation_depth": 3,
+ "navigation_depth": 4,
}
# Add any paths that contain custom static files (such as style sheets) here,
diff --git a/docs/index.rst b/docs/index.rst
index 3a7358288..1e43cf99e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -41,7 +41,7 @@ Document Structure
Interday Strategy: Portfolio Management
Intraday Trading: Model&Strategy Testing
Qlib Recorder: Experiment Management
- Aanalysis: Evaluation & Results Analysis
+ Analysis: Evaluation & Results Analysis
.. toctree::
:maxdepth: 3
diff --git a/docs/introduction/introduction.rst b/docs/introduction/introduction.rst
index 3e4d11e28..06fac46fa 100644
--- a/docs/introduction/introduction.rst
+++ b/docs/introduction/introduction.rst
@@ -21,27 +21,27 @@ Framework
At the module level, Qlib is a platform that consists of above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
-====================== ==============================================================================
-Name Description
-====================== ==============================================================================
-`Data layer` `DataServer` focuses on providing high-performance infrastructure for users to
- manage and retrieve raw data. `DataEnhancement` will preprocess the data and
- provide the best dataset to be fed into the models.
-`Interday Model` `Interday model` focuses on producing prediction scores (aka. `alpha`). Models
- are trained by `Model Creator` and managed by `Model Manager`. Users could
- choose one or multiple models for prediction. Multiple models could be combined
- with `Ensemble` module.
-`Interday Strategy` `Portfolio Generator` will take prediction scores as input and output the
- orders based on the current position to achieve the target portfolio.
+======================== ==============================================================================
+Name Description
+======================== ==============================================================================
+`Infrastructure` layer `Infrastructure` layer provides underlying support for Quant research.
+ `DataServer` provides high-performance infrastructure for users to manage
+ and retrieve raw data. `Trainer` provides flexible interface to control
+ the training process of models which enable algorithms controlling the
+ training process.
-`Intraday Trading` `Order Executor` is responsible for executing orders output by
- `Interday Strategy` and returning the executed results.
+`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
+ `Information Extractor` extracts data for models. `Forecast Model` focuses
+ on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
+ modules. With these signals `Portfolio Generator` will generate the target
+ portfolio and produce orders to be executed by `Order Executor`.
-`Analysis` Users could get a detailed analysis report of forecasting signals and portfolios
- in this part.
-====================== ==============================================================================
+`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
+ system. `Analyser` module will provide users detailed analysis reports of
+ forecasting signals, portfolios and execution results
+======================== ==============================================================================
- The modules with hand-drawn style are under development and will be released in the future.
- The modules with dashed borders are highly user-customizable and extendible.
diff --git a/docs/introduction/quick.rst b/docs/introduction/quick.rst
index a367e2dde..32752fd83 100644
--- a/docs/introduction/quick.rst
+++ b/docs/introduction/quick.rst
@@ -84,7 +84,7 @@ Auto Quant Research Workflow
- Run ``examples/workflow_by_code.ipynb`` with jupyter notebook
Users can have portfolio analysis or prediction score (model prediction) analysis by run ``examples/workflow_by_code.ipynb``.
- Graphical Reports
- Users can get graphical reports about the analysis, please refer to `Aanalysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
+ Users can get graphical reports about the analysis, please refer to `Analysis: Evaluation & Results Analysis <../component/report.html>`_ for more details.
diff --git a/docs/reference/api.rst b/docs/reference/api.rst
index 76d2a74a5..f21a9f518 100644
--- a/docs/reference/api.rst
+++ b/docs/reference/api.rst
@@ -23,16 +23,13 @@ Filter
.. automodule:: qlib.data.filter
:members:
-Feature
---------------------
-
Class
-~~~~~~~~~~~~~~~~~~~~
+--------------------
.. automodule:: qlib.data.base
:members:
Operator
-~~~~~~~~~~~~~~~~~~~~
+--------------------
.. automodule:: qlib.data.ops
:members:
@@ -56,16 +53,33 @@ Cache
.. autoclass:: qlib.data.cache.DiskDatasetCache
:members:
+Dataset
+---------------
+
+Dataset Class
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.__init__
+ :members:
+
+Data Loader
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.loader
+ :members:
+
+Data Handler
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.handler
+ :members:
+
+Processor
+~~~~~~~~~~~~~~~~~~~~
+.. automodule:: qlib.data.dataset.processor
+ :members:
+
Contrib
====================
-
-Data Handler
----------------
-.. automodule:: qlib.data.dataset.handler
- :members:
-
Model
--------------------
.. automodule:: qlib.model.base
diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst
index af89a098e..423d7edf8 100644
--- a/docs/start/initialization.rst
+++ b/docs/start/initialization.rst
@@ -12,14 +12,16 @@ Initialization
Please follow the steps below to initialize ``Qlib``.
-- Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance `_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>` for more information about customized dataset.
+Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance `_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>`_ for more information about customized dataset.
+
.. code-block:: bash
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
- Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
+
+Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`,
-- Initialize Qlib before calling other APIs: run following code in python.
+Initialize Qlib before calling other APIs: run following code in python.
.. code-block:: Python
diff --git a/docs/start/integration.rst b/docs/start/integration.rst
index 5276729b5..102d88425 100644
--- a/docs/start/integration.rst
+++ b/docs/start/integration.rst
@@ -5,7 +5,7 @@ Custom Model Integration
Introduction
===================
-``Qlib`` provides ``lightGBM`` and ``Dnn`` model as the baseline of ``Interday Model``. In addition to the default model, users can integrate their own custom models into ``Qlib``.
+``Qlib``'s `Model Zoo` includes models such as ``LightGBM``, ``DNN``, ``LSTM``, etc.. These models are treated as the baselines of ``Interday Model``. In addition to the default models ``Qlib`` provide, users can integrate their own custom models into ``Qlib``.
Users can integrate their own custom models according to the following steps.
@@ -32,79 +32,76 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
- Override the `fit` method
- ``Qlib`` calls the fit method to train the model
- - The parameters must include training feature `x_train`, training label `y_train`, test feature `x_valid`, test label `y_valid` at least.
- - The parameters could include some optional parameters with default values, such as train weight `w_train`, test weight `w_valid` and `num_boost_round = 1000`.
+ - The parameters must include training feature `dataset`.
+ - The parameters could include some optional parameters with default values, such as `num_boost_round = 1000` for `GBDT`.
- Code Example: In the following example, `num_boost_round = 1000` is an optional parameter.
.. code-block:: Python
- def fit(self, x_train:pd.DataFrame, y_train:pd.DataFrame, x_valid:pd.DataFrame, y_valid:pd.DataFrame,
- w_train:pd.DataFrame = None, w_valid:pd.DataFrame = None, num_boost_round = 1000, **kwargs):
+ def fit(self, dataset: DatasetH, num_boost_round = 1000, **kwargs):
+
+ # prepare dataset for lgb training and evaluation
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
# Lightgbm need 1D array as its label
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
- y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values)
+ y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
else:
- raise ValueError('LightGBM doesn\'t support multi-label training')
+ raise ValueError("LightGBM doesn't support multi-label training")
- w_train_weight = None if w_train is None else w_train.values
- w_valid_weight = None if w_valid is None else w_valid.values
+ dtrain = lgb.Dataset(x_train.values, label=y_train)
+ dvalid = lgb.Dataset(x_valid.values, label=y_valid)
- dtrain = lgb.Dataset(x_train.values, label=y_train_1d, weight=w_train_weight)
- dvalid = lgb.Dataset(x_valid.values, label=y_valid_1d, weight=w_valid_weight)
- self._model = lgb.train(
- self._params,
- dtrain,
+ # fit the model
+ self.model = lgb.train(
+ self.params,
+ dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
- valid_names=['train', 'valid'],
+ valid_names=["train", "valid"],
+ early_stopping_rounds=early_stopping_rounds,
+ verbose_eval=verbose_eval,
+ evals_result=evals_result,
**kwargs
)
- Override the `predict` method
- - The parameters include the test features.
+ - The parameters must include training feature `dataset`, which will be userd to get the test dataset.
- Return the `prediction score`.
- Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the fit method.
- - Code Example: In the following example, users need to use dnn to predict the label(such as `preds`) of test data `x_test` and return it.
+ - Code Example: In the following example, users need to use `LightGBM` to predict the label(such as `preds`) of test data `x_test` and return it.
.. code-block:: Python
- def predict(self, x_test:pd.DataFrame, **kwargs)-> numpy.ndarray:
- if self._model is None:
- raise ValueError('model is not fitted yet!')
- return self._model.predict(x_test.values)
+ def predict(self, dataset: DatasetH, **kwargs)-> pandas.Series:
+ if self.model is None:
+ raise ValueError("model is not fitted yet!")
+ x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
+ return pd.Series(self.model.predict(x_test.values), index=x_test.index)
-- Override the `save` method & `load` method
- - The `save` method parameter includes the a `filename` that represents an absolute path, user need to save model into the path.
- - The `load` method parameter includes the a `buffer` read from the `filename` passed in the `save` method, users need to load model from the `buffer`.
- - Code Example:
+- Override the `finetune` method
+ - The parameters must include training feature `dataset`.
+ - Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
.. code-block:: Python
- def save(self, filename):
- if self._model is None:
- raise ValueError('model is not fitted yet!')
- self._model.save_model(filename)
-
- def load(self, buffer):
- self._model = lgb.Booster(params={'model_str': buffer.decode('utf-8')})
-
-.. Without tuner, this part will not be used
-.. - Override the `score` method(This step is optional)
-.. - The parameters include the test features and test labels.
-.. - Return the evaluation score of the model. It's recommended to adopt the loss between labels and `prediction score`.
-.. - Code Example: In the following example, users need to calculate the weighted loss with test data `x_test`, test label `y_test` and the weight `w_test`.
-.. .. code-block:: Python
-..
-.. def score(self, x_test:pd.Dataframe, y_test:pd.Dataframe, w_test:pd.DataFrame = None) -> float:
-.. # Remove rows from x, y and w, which contain Nan in any columns in y_test.
-.. x_test, y_test, w_test = drop_nan_by_y_index(x_test, y_test, w_test)
-.. preds = self.predict(x_test)
-.. w_test_weight = None if w_test is None else w_test.values
-.. scorer = mean_squared_error if self.loss_type == 'mse' else roc_auc_score
-.. return scorer(y_test.values, preds, sample_weight=w_test_weight)
+ def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
+ dtrain, _ = self._prepare_data(dataset)
+ self.model = lgb.train(
+ self.params,
+ dtrain,
+ num_boost_round=num_boost_round,
+ init_model=self.model,
+ valid_sets=[dtrain],
+ valid_names=["train"],
+ verbose_eval=verbose_eval,
+ )
Configuration File
=======================
-The configuration file is described in detail in the `estimator <../component/estimator.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file.
+The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. In order to integrate the custom model into ``Qlib``, users need to modify the "model" field in the configuration file.
- Example: The following example describes the `model` field of configuration file about the custom lightgbm model mentioned above, where `module_path` is the module path, `class` is the class name, and `args` is the hyperparameter passed into the __init__ method. All parameters in the field is passed to `self._params` by `\*\*kwargs` in `__init__` except `loss = mse`.
@@ -124,20 +121,20 @@ The configuration file is described in detail in the `estimator <../component/es
num_leaves: 210
num_threads: 20
-Users could find configuration file of the baseline of the ``Model`` in ``qlib/examples/estimator/estimator_config.yaml`` and ``qlib/examples/estimator/estimator_config_dnn.yaml``
+Users could find configuration file of the baselines of the ``Model`` in ``examples/benchmarks``. All the configurations of different models are listed under the corresponding model folder.
Model Testing
=====================
-Assuming that the configuration file is ``examples/estimator/estimator_config.yaml``, users can run the following command to test the custom model:
+Assuming that the configuration file is ``examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml``, users can run the following command to test the custom model:
.. code-block:: bash
cd examples # Avoid running program under the directory contains `qlib`
- estimator -c estimator/estimator_config.yaml
+ qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml
-.. note:: ``estimator`` is a built-in command of ``Qlib``.
+.. note:: ``qrun`` is a built-in command of ``Qlib``.
-Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/train_backtest_analyze.ipynb``.
+Also, ``Model`` can also be tested as a single module. An example has been given in ``examples/workflow_by_code.ipynb``.
Reference
diff --git a/examples/benchmarks/ALSTM/README.md b/examples/benchmarks/ALSTM/README.md
new file mode 100644
index 000000000..cd9dd3493
--- /dev/null
+++ b/examples/benchmarks/ALSTM/README.md
@@ -0,0 +1,10 @@
+# ALSTM
+
+- ALSTM contains a temporal attentive aggregation layer based on normal LSTM.
+
+- The code used in Qlib is a pyTorch implementation of Code: https://github.com/fulifeng/Adv-ALSTM
+
+- Paper: A dual-stage attention-based recurrent neural network for time series prediction.
+
+ https://www.ijcai.org/Proceedings/2017/0366.pdf
+
diff --git a/examples/benchmarks/ALSTM/requirements.txt b/examples/benchmarks/ALSTM/requirements.txt
new file mode 100644
index 000000000..1fc2779c0
--- /dev/null
+++ b/examples/benchmarks/ALSTM/requirements.txt
@@ -0,0 +1,4 @@
+numpy==1.17.4
+pandas==1.1.2
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm.yaml
new file mode 100644
index 000000000..dd57761f3
--- /dev/null
+++ b/examples/benchmarks/ALSTM/workflow_config_alstm.yaml
@@ -0,0 +1,83 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: ALSTM
+ module_path: qlib.contrib.model.pytorch_alstm
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ num_layers: 2
+ dropout: 0.0
+ n_epochs: 200
+ lr: 1e-3
+ early_stop: 20
+ batch_size: 800
+ metric: loss
+ loss: mse
+ seed: 0
+ GPU: 0
+ rnn_type: GRU
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: ALPHA360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml
index 8bf3bb72b..9c15dc25b 100644
--- a/examples/benchmarks/CatBoost/workflow_config_catboost.yaml
+++ b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml
@@ -30,8 +30,13 @@ task:
module_path: qlib.contrib.model.catboost_model
kwargs:
loss: RMSE
- iterations: 5
- learning_rate: 0.03
+ learning_rate: 0.0421
+ subsample: 0.8789
+ max_depth: 6
+ num_leaves: 100
+ thread_count: 20
+ grow_policy: Lossguide
+ bootstrap_type: Poisson
dataset:
class: DatasetH
module_path: qlib.data.dataset
@@ -56,4 +61,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ config: *port_analysis_config
diff --git a/examples/benchmarks/GATs/README.md b/examples/benchmarks/GATs/README.md
new file mode 100644
index 000000000..f432b6c5b
--- /dev/null
+++ b/examples/benchmarks/GATs/README.md
@@ -0,0 +1,5 @@
+# GATs
+* Graph Attention Networks(GATs) leverage masked self-attentional layers on graph-structured data. The nodes in stacked layers have different weights and they are able to attend over their
+neighborhoods’ features, without requiring any kind of costly matrix operation (such as inversion) or depending on knowing the graph structure upfront.
+* This code used in Qlib is implemented with PyTorch by ourselves.
+* Paper: Graph Attention Networks https://arxiv.org/pdf/1710.10903.pdf
\ No newline at end of file
diff --git a/examples/benchmarks/GATs/workflow_config_gats.yaml b/examples/benchmarks/GATs/workflow_config_gats.yaml
index 37bced99d..33aa0fe8d 100644
--- a/examples/benchmarks/GATs/workflow_config_gats.yaml
+++ b/examples/benchmarks/GATs/workflow_config_gats.yaml
@@ -36,7 +36,6 @@ task:
n_epochs: 200
lr: 1e-3
early_stop: 20
- batch_size: 800
metric: loss
loss: mse
base_model: LSTM
diff --git a/examples/benchmarks/GRU/workflow_config_gru.yaml b/examples/benchmarks/GRU/workflow_config_gru.yaml
index e9e6224e6..bdfcd4e55 100644
--- a/examples/benchmarks/GRU/workflow_config_gru.yaml
+++ b/examples/benchmarks/GRU/workflow_config_gru.yaml
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
@@ -37,7 +51,7 @@ task:
lr: 1e-3
early_stop: 20
batch_size: 800
- metric: IC
+ metric: loss
loss: mse
seed: 0
GPU: 0
@@ -46,7 +60,7 @@ task:
module_path: qlib.data.dataset
kwargs:
handler:
- class: ALPHA360_Denoise
+ class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
diff --git a/examples/benchmarks/HATS/README.md b/examples/benchmarks/HATS/README.md
new file mode 100644
index 000000000..b70dbff25
--- /dev/null
+++ b/examples/benchmarks/HATS/README.md
@@ -0,0 +1,15 @@
+## Requirement
+
+* pandas==1.1.2
+* numpy==1.17.4
+* scikit_learn==0.23.2
+* torch==1.7.0
+
+## HATS
+
+* HATS is a a hierarchical attention network for stock prediction which uses relational data for stock market prediction. HATS selectively aggregates information
+on different relation types and adds the information to the representations of each company. HATS is used as a relational modeling module with initialized node representations.Furthermore, HATS
+can predict not only individual stock prices but also market index movements, which is similar to the graph classification task.
+
+* HATS uses pretrained model of GRU and LSTM. The code of GRU and LSTM used in Qlib is a pyTorch implemention of GRU and LSTM.
+* Paper address:HATS: A Hierarchical Graph Attention Network for Stock Movement Prediction https://arxiv.org/pdf/1908.07999.pdf
\ No newline at end of file
diff --git a/examples/benchmarks/HATS/requirements.txt b/examples/benchmarks/HATS/requirements.txt
new file mode 100644
index 000000000..16de0a438
--- /dev/null
+++ b/examples/benchmarks/HATS/requirements.txt
@@ -0,0 +1,4 @@
+pandas==1.1.2
+numpy==1.17.4
+scikit_learn==0.23.2
+torch==1.7.0
diff --git a/examples/benchmarks/HATS/worflow_config_hats.yaml b/examples/benchmarks/HATS/worflow_config_hats.yaml
new file mode 100644
index 000000000..b08df14e0
--- /dev/null
+++ b/examples/benchmarks/HATS/worflow_config_hats.yaml
@@ -0,0 +1,77 @@
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: HATS
+ module_path: qlib.contrib.model.pytorch_hats
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ num_layers: 2
+ dropout: 0.6
+ n_epochs: 200
+ lr: 1e-3
+ early_stop: 20
+ metric: loss
+ loss: mse
+ base_model: GRU
+ seed: 0
+ GPU: 0
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: ALPHA360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/benchmarks/LSTM/workflow_config_lstm.yaml b/examples/benchmarks/LSTM/workflow_config_lstm.yaml
index 354149dae..6512a0df3 100644
--- a/examples/benchmarks/LSTM/workflow_config_lstm.yaml
+++ b/examples/benchmarks/LSTM/workflow_config_lstm.yaml
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
@@ -37,7 +51,7 @@ task:
lr: 1e-3
early_stop: 20
batch_size: 800
- metric: IC
+ metric: loss
loss: mse
seed: 0
GPU: 0
@@ -46,7 +60,7 @@ task:
module_path: qlib.data.dataset
kwargs:
handler:
- class: ALPHA360_Denoise
+ class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
diff --git a/examples/benchmarks/SFM/README.md b/examples/benchmarks/SFM/README.md
new file mode 100644
index 000000000..06ca50485
--- /dev/null
+++ b/examples/benchmarks/SFM/README.md
@@ -0,0 +1,4 @@
+# State-Frequency-Memory
+- State Frequency Memory (SFM) is a novel recurrent network that uses Discrete Fourier Transform (DFT) to decompose the hidden states of memory cells and capture the multi-frequency trading patterns from past market data to make stock price predictions.
+- The code used in Qlib is a pyTorch implementation of SFM (Zhang, L., Aggarwal, C., & Qi, G. J. (2017,)).
+- Paper: Stock Price Prediction via Discovering Multi-Frequency Trading Patterns. https://www.cs.ucf.edu/~gqi/publications/kdd2017_stock.pdf.
\ No newline at end of file
diff --git a/examples/benchmarks/SFM/workflow_config_sfm.yaml b/examples/benchmarks/SFM/workflow_config_sfm.yaml
index 9086bab4a..3fa3f932c 100644
--- a/examples/benchmarks/SFM/workflow_config_sfm.yaml
+++ b/examples/benchmarks/SFM/workflow_config_sfm.yaml
@@ -8,6 +8,20 @@ data_handler_config: &data_handler_config
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1"]
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
@@ -31,27 +45,25 @@ task:
kwargs:
d_feat: 6
hidden_size: 64
- output_dim: 1
- freq_dim: 15
+ output_dim: 32
+ freq_dim: 25
dropout_W: 0.5
dropout_U: 0.5
- n_epochs: 10
+ n_epochs: 20
lr: 1e-3
- batch_size: 800
+ batch_size: 1600
early_stop: 20
eval_steps: 5
loss: mse
- lr_decay: 0.96
- lr_decay_steps: 100
- optimizer: gd
+ optimizer: adam
GPU: 1
- seed: 0
+ seed: 710
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
- class: ALPHA360_Denoise
+ class: ALPHA360
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
@@ -70,4 +82,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
- config: *port_analysis_config
\ No newline at end of file
+ config: *port_analysis_config
diff --git a/examples/benchmarks/TFT/README.md b/examples/benchmarks/TFT/README.md
new file mode 100644
index 000000000..5a6a9f153
--- /dev/null
+++ b/examples/benchmarks/TFT/README.md
@@ -0,0 +1,14 @@
+# Temporal Fusion Transformers Benchmark
+## Source
+**Reference**: Lim, Bryan, et al. "Temporal fusion transformers for interpretable multi-horizon time series forecasting." arXiv preprint arXiv:1912.09363 (2019).
+
+**GitHub**: https://github.com/google-research/google-research/tree/master/tft
+
+## Run the Workflow
+Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.
+
+### Notes
+1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
+2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
+3. The model must run in GPU, or an error will be raised.
+4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
diff --git a/examples/benchmarks/TFT/data_formatters/__init__.py b/examples/benchmarks/TFT/data_formatters/__init__.py
new file mode 100644
index 000000000..87ec3284f
--- /dev/null
+++ b/examples/benchmarks/TFT/data_formatters/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/examples/benchmarks/TFT/data_formatters/base.py b/examples/benchmarks/TFT/data_formatters/base.py
new file mode 100644
index 000000000..c68a192ba
--- /dev/null
+++ b/examples/benchmarks/TFT/data_formatters/base.py
@@ -0,0 +1,223 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Default data formatting functions for experiments.
+
+For new datasets, inherit form GenericDataFormatter and implement
+all abstract functions.
+
+These dataset-specific methods:
+1) Define the column and input types for tabular dataframes used by model
+2) Perform the necessary input feature engineering & normalisation steps
+3) Reverts the normalisation for predictions
+4) Are responsible for train, validation and test splits
+
+
+"""
+
+import abc
+import enum
+
+
+# Type defintions
+class DataTypes(enum.IntEnum):
+ """Defines numerical types of each column."""
+
+ REAL_VALUED = 0
+ CATEGORICAL = 1
+ DATE = 2
+
+
+class InputTypes(enum.IntEnum):
+ """Defines input types of each column."""
+
+ TARGET = 0
+ OBSERVED_INPUT = 1
+ KNOWN_INPUT = 2
+ STATIC_INPUT = 3
+ ID = 4 # Single column used as an entity identifier
+ TIME = 5 # Single column exclusively used as a time index
+
+
+class GenericDataFormatter(abc.ABC):
+ """Abstract base class for all data formatters.
+
+ User can implement the abstract methods below to perform dataset-specific
+ manipulations.
+
+ """
+
+ @abc.abstractmethod
+ def set_scalers(self, df):
+ """Calibrates scalers using the data supplied."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def transform_inputs(self, df):
+ """Performs feature transformation."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def format_predictions(self, df):
+ """Reverts any normalisation to give predictions in original scale."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def split_data(self, df):
+ """Performs the default train, validation and test splits."""
+ raise NotImplementedError()
+
+ @property
+ @abc.abstractmethod
+ def _column_definition(self):
+ """Defines order, input type and data type of each column."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_fixed_params(self):
+ """Defines the fixed parameters used by the model for training.
+
+ Requires the following keys:
+ 'total_time_steps': Defines the total number of time steps used by TFT
+ 'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
+ 'num_epochs': Maximum number of epochs for training
+ 'early_stopping_patience': Early stopping param for keras
+ 'multiprocessing_workers': # of cpus for data processing
+
+
+ Returns:
+ A dictionary of fixed parameters, e.g.:
+
+ fixed_params = {
+ 'total_time_steps': 252 + 5,
+ 'num_encoder_steps': 252,
+ 'num_epochs': 100,
+ 'early_stopping_patience': 5,
+ 'multiprocessing_workers': 5,
+ }
+ """
+ raise NotImplementedError
+
+ # Shared functions across data-formatters
+ @property
+ def num_classes_per_cat_input(self):
+ """Returns number of categories per relevant input.
+
+ This is seqeuently required for keras embedding layers.
+ """
+ return self._num_classes_per_cat_input
+
+ def get_num_samples_for_calibration(self):
+ """Gets the default number of training and validation samples.
+
+ Use to sub-sample the data for network calibration and a value of -1 uses
+ all available samples.
+
+ Returns:
+ Tuple of (training samples, validation samples)
+ """
+ return -1, -1
+
+ def get_column_definition(self):
+ """"Returns formatted column definition in order expected by the TFT."""
+
+ column_definition = self._column_definition
+
+ # Sanity checks first.
+ # Ensure only one ID and time column exist
+ def _check_single_column(input_type):
+
+ length = len([tup for tup in column_definition if tup[2] == input_type])
+
+ if length != 1:
+ raise ValueError("Illegal number of inputs ({}) of type {}".format(length, input_type))
+
+ _check_single_column(InputTypes.ID)
+ _check_single_column(InputTypes.TIME)
+
+ identifier = [tup for tup in column_definition if tup[2] == InputTypes.ID]
+ time = [tup for tup in column_definition if tup[2] == InputTypes.TIME]
+ real_inputs = [
+ tup
+ for tup in column_definition
+ if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME}
+ ]
+ categorical_inputs = [
+ tup
+ for tup in column_definition
+ if tup[1] == DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME}
+ ]
+
+ return identifier + time + real_inputs + categorical_inputs
+
+ def _get_input_columns(self):
+ """Returns names of all input columns."""
+ return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ def _get_tft_input_indices(self):
+ """Returns the relevant indexes and input sizes required by TFT."""
+
+ # Functions
+ def _extract_tuples_from_data_type(data_type, defn):
+ return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ def _get_locations(input_types, defn):
+ return [i for i, tup in enumerate(defn) if tup[2] in input_types]
+
+ # Start extraction
+ column_definition = [
+ tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}
+ ]
+
+ categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition)
+ real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition)
+
+ locations = {
+ "input_size": len(self._get_input_columns()),
+ "output_size": len(_get_locations({InputTypes.TARGET}, column_definition)),
+ "category_counts": self.num_classes_per_cat_input,
+ "input_obs_loc": _get_locations({InputTypes.TARGET}, column_definition),
+ "static_input_loc": _get_locations({InputTypes.STATIC_INPUT}, column_definition),
+ "known_regular_inputs": _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs),
+ "known_categorical_inputs": _get_locations(
+ {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, categorical_inputs
+ ),
+ }
+
+ return locations
+
+ def get_experiment_params(self):
+ """Returns fixed model parameters for experiments."""
+
+ required_keys = [
+ "total_time_steps",
+ "num_encoder_steps",
+ "num_epochs",
+ "early_stopping_patience",
+ "multiprocessing_workers",
+ ]
+
+ fixed_params = self.get_fixed_params()
+
+ for k in required_keys:
+ if k not in fixed_params:
+ raise ValueError("Field {}".format(k) + " missing from fixed parameter definitions!")
+
+ fixed_params["column_definition"] = self.get_column_definition()
+
+ fixed_params.update(self._get_tft_input_indices())
+
+ return fixed_params
diff --git a/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py
new file mode 100644
index 000000000..da3d14343
--- /dev/null
+++ b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py
@@ -0,0 +1,219 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Custom formatting functions for Alpha158 dataset.
+
+Defines dataset specific column definitions and data transformations.
+"""
+
+import data_formatters.base
+import libs.utils as utils
+import sklearn.preprocessing
+
+GenericDataFormatter = data_formatters.base.GenericDataFormatter
+DataTypes = data_formatters.base.DataTypes
+InputTypes = data_formatters.base.InputTypes
+
+
+class Alpha158Formatter(GenericDataFormatter):
+ """Defines and formats data for the Alpha158 dataset.
+
+ Attributes:
+ column_definition: Defines input and data type of column used in the
+ experiment.
+ identifiers: Entity identifiers used in experiments.
+ """
+
+ _column_definition = [
+ ("instrument", DataTypes.CATEGORICAL, InputTypes.ID),
+ ("LABEL0", DataTypes.REAL_VALUED, InputTypes.TARGET),
+ ("date", DataTypes.DATE, InputTypes.TIME),
+ ("month", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
+ ("day_of_week", DataTypes.CATEGORICAL, InputTypes.KNOWN_INPUT),
+ # Selected 10 features
+ ("RESI5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("WVMA5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("RSQR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("KLEN", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("RSQR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("CORR5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("CORD5", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("CORR10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("ROC60", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("RESI10", DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
+ ("const", DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
+ ]
+
+ def __init__(self):
+ """Initialises formatter."""
+
+ self.identifiers = None
+ self._real_scalers = None
+ self._cat_scalers = None
+ self._target_scaler = None
+ self._num_classes_per_cat_input = None
+
+ def split_data(self, df, valid_boundary=2016, test_boundary=2018):
+ """Splits data frame into training-validation-test data frames.
+
+ This also calibrates scaling object, and transforms data for each split.
+
+ Args:
+ df: Source data frame to split.
+ valid_boundary: Starting year for validation data
+ test_boundary: Starting year for test data
+
+ Returns:
+ Tuple of transformed (train, valid, test) data.
+ """
+
+ print("Formatting train-valid-test splits.")
+
+ index = df["year"]
+ train = df.loc[index < valid_boundary]
+ valid = df.loc[(index >= valid_boundary) & (index < test_boundary)]
+ test = df.loc[index >= test_boundary]
+
+ self.set_scalers(train)
+
+ return (self.transform_inputs(data) for data in [train, valid, test])
+
+ def set_scalers(self, df):
+ """Calibrates scalers using the data supplied.
+
+ Args:
+ df: Data to use to calibrate scalers.
+ """
+ print("Setting scalers with training data...")
+
+ column_definitions = self.get_column_definition()
+ id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions)
+ target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions)
+
+ # Extract identifiers in case required
+ self.identifiers = list(df[id_column].unique())
+
+ # Format real scalers
+ real_inputs = utils.extract_cols_from_data_type(
+ DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+
+ data = df[real_inputs].values
+ self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
+ self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
+ df[[target_column]].values
+ ) # used for predictions
+
+ # Format categorical scalers
+ categorical_inputs = utils.extract_cols_from_data_type(
+ DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+
+ categorical_scalers = {}
+ num_classes = []
+ for col in categorical_inputs:
+ # Set all to str so that we don't have mixed integer/string columns
+ srs = df[col].apply(str)
+ categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values)
+ num_classes.append(srs.nunique())
+
+ # Set categorical scaler outputs
+ self._cat_scalers = categorical_scalers
+ self._num_classes_per_cat_input = num_classes
+
+ def transform_inputs(self, df):
+ """Performs feature transformations.
+
+ This includes both feature engineering, preprocessing and normalisation.
+
+ Args:
+ df: Data frame to transform.
+
+ Returns:
+ Transformed data frame.
+
+ """
+ output = df.copy()
+
+ if self._real_scalers is None and self._cat_scalers is None:
+ raise ValueError("Scalers have not been set!")
+
+ column_definitions = self.get_column_definition()
+
+ real_inputs = utils.extract_cols_from_data_type(
+ DataTypes.REAL_VALUED, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+ categorical_inputs = utils.extract_cols_from_data_type(
+ DataTypes.CATEGORICAL, column_definitions, {InputTypes.ID, InputTypes.TIME}
+ )
+
+ # Format real inputs
+ output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)
+
+ # Format categorical inputs
+ for col in categorical_inputs:
+ string_df = df[col].apply(str)
+ output[col] = self._cat_scalers[col].transform(string_df)
+
+ return output
+
+ def format_predictions(self, predictions):
+ """Reverts any normalisation to give predictions in original scale.
+
+ Args:
+ predictions: Dataframe of model predictions.
+
+ Returns:
+ Data frame of unnormalised predictions.
+ """
+ output = predictions.copy()
+
+ column_names = predictions.columns
+
+ for col in column_names:
+ if col not in {"forecast_time", "identifier"}:
+ output[col] = self._target_scaler.inverse_transform(predictions[col])
+
+ return output
+
+ # Default params
+ def get_fixed_params(self):
+ """Returns fixed model parameters for experiments."""
+
+ fixed_params = {
+ "total_time_steps": 16 + 6,
+ "num_encoder_steps": 16,
+ "num_epochs": 100,
+ "early_stopping_patience": 5,
+ "multiprocessing_workers": 5,
+ }
+
+ return fixed_params
+
+ def get_default_model_params(self):
+ """Returns default optimised model parameters."""
+
+ model_params = {
+ "dropout_rate": 0.3,
+ "hidden_layer_size": 160,
+ "learning_rate": 0.001,
+ "minibatch_size": 64,
+ "max_gradient_norm": 0.01,
+ "num_heads": 1,
+ "stack_size": 1,
+ }
+
+ return model_params
diff --git a/examples/benchmarks/TFT/expt_settings/__init__.py b/examples/benchmarks/TFT/expt_settings/__init__.py
new file mode 100644
index 000000000..87ec3284f
--- /dev/null
+++ b/examples/benchmarks/TFT/expt_settings/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/examples/benchmarks/TFT/expt_settings/configs.py b/examples/benchmarks/TFT/expt_settings/configs.py
new file mode 100644
index 000000000..6aef0c395
--- /dev/null
+++ b/examples/benchmarks/TFT/expt_settings/configs.py
@@ -0,0 +1,95 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Default configs for TFT experiments.
+
+Contains the default output paths for data, serialised models and predictions
+for the main experiments used in the publication.
+"""
+
+import os
+
+import data_formatters.qlib_Alpha158
+
+
+class ExperimentConfig(object):
+ """Defines experiment configs and paths to outputs.
+
+ Attributes:
+ root_folder: Root folder to contain all experimental outputs.
+ experiment: Name of experiment to run.
+ data_folder: Folder to store data for experiment.
+ model_folder: Folder to store serialised models.
+ results_folder: Folder to store results.
+ data_csv_path: Path to primary data csv file used in experiment.
+ hyperparam_iterations: Default number of random search iterations for
+ experiment.
+ """
+
+ default_experiments = ["Alpha158"]
+
+ def __init__(self, experiment="volatility", root_folder=None):
+ """Creates configs based on default experiment chosen.
+
+ Args:
+ experiment: Name of experiment.
+ root_folder: Root folder to save all outputs of training.
+ """
+
+ if experiment not in self.default_experiments:
+ raise ValueError("Unrecognised experiment={}".format(experiment))
+
+ # Defines all relevant paths
+ if root_folder is None:
+ root_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "outputs")
+ print("Using root folder {}".format(root_folder))
+
+ self.root_folder = root_folder
+ self.experiment = experiment
+ self.data_folder = os.path.join(root_folder, "data", experiment)
+ self.model_folder = os.path.join(root_folder, "saved_models", experiment)
+ self.results_folder = os.path.join(root_folder, "results", experiment)
+
+ # Creates folders if they don't exist
+ for relevant_directory in [self.root_folder, self.data_folder, self.model_folder, self.results_folder]:
+ if not os.path.exists(relevant_directory):
+ os.makedirs(relevant_directory)
+
+ @property
+ def data_csv_path(self):
+ csv_map = {
+ "Alpha158": "Alpha158.csv",
+ }
+
+ return os.path.join(self.data_folder, csv_map[self.experiment])
+
+ @property
+ def hyperparam_iterations(self):
+
+ return 240 if self.experiment == "volatility" else 60
+
+ def make_data_formatter(self):
+ """Gets a data formatter object for experiment.
+
+ Returns:
+ Default DataFormatter per experiment.
+ """
+
+ data_formatter_class = {
+ "Alpha158": data_formatters.qlib_Alpha158.Alpha158Formatter,
+ }
+
+ return data_formatter_class[self.experiment]()
diff --git a/examples/benchmarks/TFT/libs/__init__.py b/examples/benchmarks/TFT/libs/__init__.py
new file mode 100644
index 000000000..87ec3284f
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/examples/benchmarks/TFT/libs/hyperparam_opt.py b/examples/benchmarks/TFT/libs/hyperparam_opt.py
new file mode 100644
index 000000000..750fdf2c1
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/hyperparam_opt.py
@@ -0,0 +1,430 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Classes used for hyperparameter optimisation.
+
+Two main classes exist:
+1) HyperparamOptManager used for optimisation on a single machine/GPU.
+2) DistributedHyperparamOptManager for multiple GPUs on different machines.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import shutil
+import libs.utils as utils
+import numpy as np
+import pandas as pd
+
+Deque = collections.deque
+
+
+class HyperparamOptManager:
+ """Manages hyperparameter optimisation using random search for a single GPU.
+
+ Attributes:
+ param_ranges: Discrete hyperparameter range for random search.
+ results: Dataframe of validation results.
+ fixed_params: Fixed model parameters per experiment.
+ saved_params: Dataframe of parameters trained.
+ best_score: Minimum validation loss observed thus far.
+ optimal_name: Key to best configuration.
+ hyperparam_folder: Where to save optimisation outputs.
+ """
+
+ def __init__(self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True):
+ """Instantiates model.
+
+ Args:
+ param_ranges: Discrete hyperparameter range for random search.
+ fixed_params: Fixed model parameters per experiment.
+ model_folder: Folder to store optimisation artifacts.
+ override_w_fixed_params: Whether to override serialsed fixed model
+ parameters with new supplied values.
+ """
+
+ self.param_ranges = param_ranges
+
+ self._max_tries = 1000
+ self.results = pd.DataFrame()
+ self.fixed_params = fixed_params
+ self.saved_params = pd.DataFrame()
+
+ self.best_score = np.Inf
+ self.optimal_name = ""
+
+ # Setup
+ # Create folder for saving if its not there
+ self.hyperparam_folder = model_folder
+ utils.create_folder_if_not_exist(self.hyperparam_folder)
+
+ self._override_w_fixed_params = override_w_fixed_params
+
+ def load_results(self):
+ """Loads results from previous hyperparameter optimisation.
+
+ Returns:
+ A boolean indicating if previous results can be loaded.
+ """
+ print("Loading results from", self.hyperparam_folder)
+
+ results_file = os.path.join(self.hyperparam_folder, "results.csv")
+ params_file = os.path.join(self.hyperparam_folder, "params.csv")
+
+ if os.path.exists(results_file) and os.path.exists(params_file):
+
+ self.results = pd.read_csv(results_file, index_col=0)
+ self.saved_params = pd.read_csv(params_file, index_col=0)
+
+ if not self.results.empty:
+ self.results.at["loss"] = self.results.loc["loss"].apply(float)
+ self.best_score = self.results.loc["loss"].min()
+
+ is_optimal = self.results.loc["loss"] == self.best_score
+ self.optimal_name = self.results.T[is_optimal].index[0]
+
+ return True
+
+ return False
+
+ def _get_params_from_name(self, name):
+ """Returns previously saved parameters given a key."""
+ params = self.saved_params
+
+ selected_params = dict(params[name])
+
+ if self._override_w_fixed_params:
+ for k in self.fixed_params:
+ selected_params[k] = self.fixed_params[k]
+
+ return selected_params
+
+ def get_best_params(self):
+ """Returns the optimal hyperparameters thus far."""
+
+ optimal_name = self.optimal_name
+
+ return self._get_params_from_name(optimal_name)
+
+ def clear(self):
+ """Clears all previous results and saved parameters."""
+ shutil.rmtree(self.hyperparam_folder)
+ os.makedirs(self.hyperparam_folder)
+ self.results = pd.DataFrame()
+ self.saved_params = pd.DataFrame()
+
+ def _check_params(self, params):
+ """Checks that parameter map is properly defined."""
+
+ valid_fields = list(self.param_ranges.keys()) + list(self.fixed_params.keys())
+ invalid_fields = [k for k in params if k not in valid_fields]
+ missing_fields = [k for k in valid_fields if k not in params]
+
+ if invalid_fields:
+ raise ValueError("Invalid Fields Found {} - Valid ones are {}".format(invalid_fields, valid_fields))
+ if missing_fields:
+ raise ValueError("Missing Fields Found {} - Valid ones are {}".format(missing_fields, valid_fields))
+
+ def _get_name(self, params):
+ """Returns a unique key for the supplied set of params."""
+
+ self._check_params(params)
+
+ fields = list(params.keys())
+ fields.sort()
+
+ return "_".join([str(params[k]) for k in fields])
+
+ def get_next_parameters(self, ranges_to_skip=None):
+ """Returns the next set of parameters to optimise.
+
+ Args:
+ ranges_to_skip: Explicitly defines a set of keys to skip.
+ """
+ if ranges_to_skip is None:
+ ranges_to_skip = set(self.results.index)
+
+ if not isinstance(self.param_ranges, dict):
+ raise ValueError("Only works for random search!")
+
+ param_range_keys = list(self.param_ranges.keys())
+ param_range_keys.sort()
+
+ def _get_next():
+ """Returns next hyperparameter set per try."""
+
+ parameters = {k: np.random.choice(self.param_ranges[k]) for k in param_range_keys}
+
+ # Adds fixed params
+ for k in self.fixed_params:
+ parameters[k] = self.fixed_params[k]
+
+ return parameters
+
+ for _ in range(self._max_tries):
+
+ parameters = _get_next()
+ name = self._get_name(parameters)
+
+ if name not in ranges_to_skip:
+ return parameters
+
+ raise ValueError("Exceeded max number of hyperparameter searches!!")
+
+ def update_score(self, parameters, loss, model, info=""):
+ """Updates the results from last optimisation run.
+
+ Args:
+ parameters: Hyperparameters used in optimisation.
+ loss: Validation loss obtained.
+ model: Model to serialised if required.
+ info: Any ancillary information to tag on to results.
+
+ Returns:
+ Boolean flag indicating if the model is the best seen so far.
+ """
+
+ if np.isnan(loss):
+ loss = np.Inf
+
+ if not os.path.isdir(self.hyperparam_folder):
+ os.makedirs(self.hyperparam_folder)
+
+ name = self._get_name(parameters)
+
+ is_optimal = self.results.empty or loss < self.best_score
+
+ # save the first model
+ if is_optimal:
+ # Try saving first, before updating info
+ if model is not None:
+ print("Optimal model found, updating")
+ model.save(self.hyperparam_folder)
+ self.best_score = loss
+ self.optimal_name = name
+
+ self.results[name] = pd.Series({"loss": loss, "info": info})
+ self.saved_params[name] = pd.Series(parameters)
+
+ self.results.to_csv(os.path.join(self.hyperparam_folder, "results.csv"))
+ self.saved_params.to_csv(os.path.join(self.hyperparam_folder, "params.csv"))
+
+ return is_optimal
+
+
+class DistributedHyperparamOptManager(HyperparamOptManager):
+ """Manages distributed hyperparameter optimisation across many gpus."""
+
+ def __init__(
+ self,
+ param_ranges,
+ fixed_params,
+ root_model_folder,
+ worker_number,
+ search_iterations=1000,
+ num_iterations_per_worker=5,
+ clear_serialised_params=False,
+ ):
+ """Instantiates optimisation manager.
+
+ This hyperparameter optimisation pre-generates #search_iterations
+ hyperparameter combinations and serialises them
+ at the start. At runtime, each worker goes through their own set of
+ parameter ranges. The pregeneration
+ allows for multiple workers to run in parallel on different machines without
+ resulting in parameter overlaps.
+
+ Args:
+ param_ranges: Discrete hyperparameter range for random search.
+ fixed_params: Fixed model parameters per experiment.
+ root_model_folder: Folder to store optimisation artifacts.
+ worker_number: Worker index definining which set of hyperparameters to
+ test.
+ search_iterations: Maximum numer of random search iterations.
+ num_iterations_per_worker: How many iterations are handled per worker.
+ clear_serialised_params: Whether to regenerate hyperparameter
+ combinations.
+ """
+
+ max_workers = int(np.ceil(search_iterations / num_iterations_per_worker))
+
+ # Sanity checks
+ if worker_number > max_workers:
+ raise ValueError(
+ "Worker number ({}) cannot be larger than the total number of workers!".format(max_workers)
+ )
+ if worker_number > search_iterations:
+ raise ValueError(
+ "Worker number ({}) cannot be larger than the max search iterations ({})!".format(
+ worker_number, search_iterations
+ )
+ )
+
+ print("*** Creating hyperparameter manager for worker {} ***".format(worker_number))
+
+ hyperparam_folder = os.path.join(root_model_folder, str(worker_number))
+ super().__init__(param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True)
+
+ serialised_ranges_folder = os.path.join(root_model_folder, "hyperparams")
+ if clear_serialised_params:
+ print("Regenerating hyperparameter list")
+ if os.path.exists(serialised_ranges_folder):
+ shutil.rmtree(serialised_ranges_folder)
+
+ utils.create_folder_if_not_exist(serialised_ranges_folder)
+
+ self.serialised_ranges_path = os.path.join(serialised_ranges_folder, "ranges_{}.csv".format(search_iterations))
+ self.hyperparam_folder = hyperparam_folder # override
+ self.worker_num = worker_number
+ self.total_search_iterations = search_iterations
+ self.num_iterations_per_worker = num_iterations_per_worker
+ self.global_hyperparam_df = self.load_serialised_hyperparam_df()
+ self.worker_search_queue = self._get_worker_search_queue()
+
+ @property
+ def optimisation_completed(self):
+ return False if self.worker_search_queue else True
+
+ def get_next_parameters(self):
+ """Returns next dictionary of hyperparameters to optimise."""
+ param_name = self.worker_search_queue.pop()
+
+ params = self.global_hyperparam_df.loc[param_name, :].to_dict()
+
+ # Always override!
+ for k in self.fixed_params:
+ print("Overriding saved {}: {}".format(k, self.fixed_params[k]))
+
+ params[k] = self.fixed_params[k]
+
+ return params
+
+ def load_serialised_hyperparam_df(self):
+ """Loads serialsed hyperparameter ranges from file.
+
+ Returns:
+ DataFrame containing hyperparameter combinations.
+ """
+ print(
+ "Loading params for {} search iterations form {}".format(
+ self.total_search_iterations, self.serialised_ranges_path
+ )
+ )
+
+ if os.path.exists(self.serialised_ranges_folder):
+ df = pd.read_csv(self.serialised_ranges_path, index_col=0)
+ else:
+ print("Unable to load - regenerating serach ranges instead")
+ df = self.update_serialised_hyperparam_df()
+
+ return df
+
+ def update_serialised_hyperparam_df(self):
+ """Regenerates hyperparameter combinations and saves to file.
+
+ Returns:
+ DataFrame containing hyperparameter combinations.
+ """
+ search_df = self._generate_full_hyperparam_df()
+
+ print(
+ "Serialising params for {} search iterations to {}".format(
+ self.total_search_iterations, self.serialised_ranges_path
+ )
+ )
+
+ search_df.to_csv(self.serialised_ranges_path)
+
+ return search_df
+
+ def _generate_full_hyperparam_df(self):
+ """Generates actual hyperparameter combinations.
+
+ Returns:
+ DataFrame containing hyperparameter combinations.
+ """
+
+ np.random.seed(131) # for reproducibility of hyperparam list
+
+ name_list = []
+ param_list = []
+ for _ in range(self.total_search_iterations):
+ params = super().get_next_parameters(name_list)
+
+ name = self._get_name(params)
+
+ name_list.append(name)
+ param_list.append(params)
+
+ full_search_df = pd.DataFrame(param_list, index=name_list)
+
+ return full_search_df
+
+ def clear(self): # reset when cleared
+ """Clears results for hyperparameter manager and resets."""
+ super().clear()
+ self.worker_search_queue = self._get_worker_search_queue()
+
+ def load_results(self):
+ """Load results from file and queue parameter combinations to try.
+
+ Returns:
+ Boolean indicating if results were successfully loaded.
+ """
+ success = super().load_results()
+
+ if success:
+ self.worker_search_queue = self._get_worker_search_queue()
+
+ return success
+
+ def _get_worker_search_queue(self):
+ """Generates the queue of param combinations for current worker.
+
+ Returns:
+ Queue of hyperparameter combinations outstanding.
+ """
+ global_df = self.assign_worker_numbers(self.global_hyperparam_df)
+ worker_df = global_df[global_df["worker"] == self.worker_num]
+
+ left_overs = [s for s in worker_df.index if s not in self.results.columns]
+
+ return Deque(left_overs)
+
+ def assign_worker_numbers(self, df):
+ """Updates parameter combinations with the index of the worker used.
+
+ Args:
+ df: DataFrame of parameter combinations.
+
+ Returns:
+ Updated DataFrame with worker number.
+ """
+ output = df.copy()
+
+ n = self.total_search_iterations
+ batch_size = self.num_iterations_per_worker
+
+ max_worker_num = int(np.ceil(n / batch_size))
+
+ worker_idx = np.concatenate([np.tile(i + 1, self.num_iterations_per_worker) for i in range(max_worker_num)])
+
+ output["worker"] = worker_idx[: len(output)]
+
+ return output
diff --git a/examples/benchmarks/TFT/libs/tft_model.py b/examples/benchmarks/TFT/libs/tft_model.py
new file mode 100644
index 000000000..658bae60f
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/tft_model.py
@@ -0,0 +1,1280 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Temporal Fusion Transformer Model.
+
+Contains the full TFT architecture and associated components. Defines functions
+for training, evaluation and prediction using simple Pandas Dataframe inputs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import json
+import os
+import shutil
+
+import data_formatters.base
+import libs.utils as utils
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+
+# Layer definitions.
+concat = tf.keras.backend.concatenate
+stack = tf.keras.backend.stack
+K = tf.keras.backend
+Add = tf.keras.layers.Add
+LayerNorm = tf.keras.layers.LayerNormalization
+Dense = tf.keras.layers.Dense
+Multiply = tf.keras.layers.Multiply
+Dropout = tf.keras.layers.Dropout
+Activation = tf.keras.layers.Activation
+Lambda = tf.keras.layers.Lambda
+
+# Default input types.
+InputTypes = data_formatters.base.InputTypes
+
+
+# Layer utility functions.
+def linear_layer(size, activation=None, use_time_distributed=False, use_bias=True):
+ """Returns simple Keras linear layer.
+
+ Args:
+ size: Output size
+ activation: Activation function to apply if required
+ use_time_distributed: Whether to apply layer across time
+ use_bias: Whether bias should be included in layer
+ """
+ linear = tf.keras.layers.Dense(size, activation=activation, use_bias=use_bias)
+ if use_time_distributed:
+ linear = tf.keras.layers.TimeDistributed(linear)
+ return linear
+
+
+def apply_mlp(
+ inputs, hidden_size, output_size, output_activation=None, hidden_activation="tanh", use_time_distributed=False
+):
+ """Applies simple feed-forward network to an input.
+
+ Args:
+ inputs: MLP inputs
+ hidden_size: Hidden state size
+ output_size: Output size of MLP
+ output_activation: Activation function to apply on output
+ hidden_activation: Activation function to apply on input
+ use_time_distributed: Whether to apply across time
+
+ Returns:
+ Tensor for MLP outputs.
+ """
+ if use_time_distributed:
+ hidden = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_size, activation=hidden_activation))(
+ inputs
+ )
+ return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(output_size, activation=output_activation))(hidden)
+ else:
+ hidden = tf.keras.layers.Dense(hidden_size, activation=hidden_activation)(inputs)
+ return tf.keras.layers.Dense(output_size, activation=output_activation)(hidden)
+
+
+def apply_gating_layer(x, hidden_layer_size, dropout_rate=None, use_time_distributed=True, activation=None):
+ """Applies a Gated Linear Unit (GLU) to an input.
+
+ Args:
+ x: Input to gating layer
+ hidden_layer_size: Dimension of GLU
+ dropout_rate: Dropout rate to apply if any
+ use_time_distributed: Whether to apply across time
+ activation: Activation function to apply to the linear feature transform if
+ necessary
+
+ Returns:
+ Tuple of tensors for: (GLU output, gate)
+ """
+
+ if dropout_rate is not None:
+ x = tf.keras.layers.Dropout(dropout_rate)(x)
+
+ if use_time_distributed:
+ activation_layer = tf.keras.layers.TimeDistributed(
+ tf.keras.layers.Dense(hidden_layer_size, activation=activation)
+ )(x)
+ gated_layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid"))(x)
+ else:
+ activation_layer = tf.keras.layers.Dense(hidden_layer_size, activation=activation)(x)
+ gated_layer = tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid")(x)
+
+ return tf.keras.layers.Multiply()([activation_layer, gated_layer]), gated_layer
+
+
+def add_and_norm(x_list):
+ """Applies skip connection followed by layer normalisation.
+
+ Args:
+ x_list: List of inputs to sum for skip connection
+
+ Returns:
+ Tensor output from layer.
+ """
+ tmp = Add()(x_list)
+ tmp = LayerNorm()(tmp)
+ return tmp
+
+
+def gated_residual_network(
+ x,
+ hidden_layer_size,
+ output_size=None,
+ dropout_rate=None,
+ use_time_distributed=True,
+ additional_context=None,
+ return_gate=False,
+):
+ """Applies the gated residual network (GRN) as defined in paper.
+
+ Args:
+ x: Network inputs
+ hidden_layer_size: Internal state size
+ output_size: Size of output layer
+ dropout_rate: Dropout rate if dropout is applied
+ use_time_distributed: Whether to apply network across time dimension
+ additional_context: Additional context vector to use if relevant
+ return_gate: Whether to return GLU gate for diagnostic purposes
+
+ Returns:
+ Tuple of tensors for: (GRN output, GLU gate)
+ """
+
+ # Setup skip connection
+ if output_size is None:
+ output_size = hidden_layer_size
+ skip = x
+ else:
+ linear = Dense(output_size)
+ if use_time_distributed:
+ linear = tf.keras.layers.TimeDistributed(linear)
+ skip = linear(x)
+
+ # Apply feedforward network
+ hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(x)
+ if additional_context is not None:
+ hidden = hidden + linear_layer(
+ hidden_layer_size, activation=None, use_time_distributed=use_time_distributed, use_bias=False
+ )(additional_context)
+ hidden = tf.keras.layers.Activation("elu")(hidden)
+ hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(hidden)
+
+ gating_layer, gate = apply_gating_layer(
+ hidden, output_size, dropout_rate=dropout_rate, use_time_distributed=use_time_distributed, activation=None
+ )
+
+ if return_gate:
+ return add_and_norm([skip, gating_layer]), gate
+ else:
+ return add_and_norm([skip, gating_layer])
+
+
+# Attention Components.
+def get_decoder_mask(self_attn_inputs):
+ """Returns causal mask to apply for self-attention layer.
+
+ Args:
+ self_attn_inputs: Inputs to self attention layer to determine mask shape
+ """
+ len_s = tf.shape(self_attn_inputs)[1]
+ bs = tf.shape(self_attn_inputs)[:1]
+ mask = K.cumsum(tf.eye(len_s, batch_shape=bs), 1)
+ return mask
+
+
+class ScaledDotProductAttention:
+ """Defines scaled dot product attention layer.
+
+ Attributes:
+ dropout: Dropout rate to use
+ activation: Normalisation function for scaled dot product attention (e.g.
+ softmax by default)
+ """
+
+ def __init__(self, attn_dropout=0.0):
+ self.dropout = Dropout(attn_dropout)
+ self.activation = Activation("softmax")
+
+ def __call__(self, q, k, v, mask):
+ """Applies scaled dot product attention.
+
+ Args:
+ q: Queries
+ k: Keys
+ v: Values
+ mask: Masking if required -- sets softmax to very large value
+
+ Returns:
+ Tuple of (layer outputs, attention weights)
+ """
+ temper = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype="float32"))
+ attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / temper)([q, k]) # shape=(batch, q, k)
+ if mask is not None:
+ mmask = Lambda(lambda x: (-1e9) * (1.0 - K.cast(x, "float32")))(mask) # setting to infinity
+ attn = Add()([attn, mmask])
+ attn = self.activation(attn)
+ attn = self.dropout(attn)
+ output = Lambda(lambda x: K.batch_dot(x[0], x[1]))([attn, v])
+ return output, attn
+
+
+class InterpretableMultiHeadAttention:
+ """Defines interpretable multi-head attention layer.
+
+ Attributes:
+ n_head: Number of heads
+ d_k: Key/query dimensionality per head
+ d_v: Value dimensionality
+ dropout: Dropout rate to apply
+ qs_layers: List of queries across heads
+ ks_layers: List of keys across heads
+ vs_layers: List of values across heads
+ attention: Scaled dot product attention layer
+ w_o: Output weight matrix to project internal state to the original TFT
+ state size
+ """
+
+ def __init__(self, n_head, d_model, dropout):
+ """Initialises layer.
+
+ Args:
+ n_head: Number of heads
+ d_model: TFT state dimensionality
+ dropout: Dropout discard rate
+ """
+
+ self.n_head = n_head
+ self.d_k = self.d_v = d_k = d_v = d_model // n_head
+ self.dropout = dropout
+
+ self.qs_layers = []
+ self.ks_layers = []
+ self.vs_layers = []
+
+ # Use same value layer to facilitate interp
+ vs_layer = Dense(d_v, use_bias=False)
+
+ for _ in range(n_head):
+ self.qs_layers.append(Dense(d_k, use_bias=False))
+ self.ks_layers.append(Dense(d_k, use_bias=False))
+ self.vs_layers.append(vs_layer) # use same vs_layer
+
+ self.attention = ScaledDotProductAttention()
+ self.w_o = Dense(d_model, use_bias=False)
+
+ def __call__(self, q, k, v, mask=None):
+ """Applies interpretable multihead attention.
+
+ Using T to denote the number of time steps fed into the transformer.
+
+ Args:
+ q: Query tensor of shape=(?, T, d_model)
+ k: Key of shape=(?, T, d_model)
+ v: Values of shape=(?, T, d_model)
+ mask: Masking if required with shape=(?, T, T)
+
+ Returns:
+ Tuple of (layer outputs, attention weights)
+ """
+ n_head = self.n_head
+
+ heads = []
+ attns = []
+ for i in range(n_head):
+ qs = self.qs_layers[i](q)
+ ks = self.ks_layers[i](k)
+ vs = self.vs_layers[i](v)
+ head, attn = self.attention(qs, ks, vs, mask)
+
+ head_dropout = Dropout(self.dropout)(head)
+ heads.append(head_dropout)
+ attns.append(attn)
+ head = K.stack(heads) if n_head > 1 else heads[0]
+ attn = K.stack(attns)
+
+ outputs = K.mean(head, axis=0) if n_head > 1 else head
+ outputs = self.w_o(outputs)
+ outputs = Dropout(self.dropout)(outputs) # output dropout
+
+ return outputs, attn
+
+
+class TFTDataCache(object):
+ """Caches data for the TFT."""
+
+ _data_cache = {}
+
+ @classmethod
+ def update(cls, data, key):
+ """Updates cached data.
+
+ Args:
+ data: Source to update
+ key: Key to dictionary location
+ """
+ cls._data_cache[key] = data
+
+ @classmethod
+ def get(cls, key):
+ """Returns data stored at key location."""
+ return cls._data_cache[key].copy()
+
+ @classmethod
+ def contains(cls, key):
+ """Retuns boolean indicating whether key is present in cache."""
+
+ return key in cls._data_cache
+
+
+# TFT model definitions.
+class TemporalFusionTransformer(object):
+ """Defines Temporal Fusion Transformer.
+
+ Attributes:
+ name: Name of model
+ time_steps: Total number of input time steps per forecast date (i.e. Width
+ of Temporal fusion decoder N)
+ input_size: Total number of inputs
+ output_size: Total number of outputs
+ category_counts: Number of categories per categorical variable
+ n_multiprocessing_workers: Number of workers to use for parallel
+ computations
+ column_definition: List of tuples of (string, DataType, InputType) that
+ define each column
+ quantiles: Quantiles to forecast for TFT
+ use_cudnn: Whether to use Keras CuDNNLSTM or standard LSTM layers
+ hidden_layer_size: Internal state size of TFT
+ dropout_rate: Dropout discard rate
+ max_gradient_norm: Maximum norm for gradient clipping
+ learning_rate: Initial learning rate of ADAM optimizer
+ minibatch_size: Size of minibatches for training
+ num_epochs: Maximum number of epochs for training
+ early_stopping_patience: Maximum number of iterations of non-improvement
+ before early stopping kicks in
+ num_encoder_steps: Size of LSTM encoder -- i.e. number of past time steps
+ before forecast date to use
+ num_stacks: Number of self-attention layers to apply (default is 1 for basic
+ TFT)
+ num_heads: Number of heads for interpretable mulit-head attention
+ model: Keras model for TFT
+ """
+
+ def __init__(self, raw_params, use_cudnn=False):
+ """Builds TFT from parameters.
+
+ Args:
+ raw_params: Parameters to define TFT
+ use_cudnn: Whether to use CUDNN GPU optimised LSTM
+ """
+
+ self.name = self.__class__.__name__
+
+ params = dict(raw_params) # copy locally
+
+ # Data parameters
+ self.time_steps = int(params["total_time_steps"])
+ self.input_size = int(params["input_size"])
+ self.output_size = int(params["output_size"])
+ self.category_counts = json.loads(str(params["category_counts"]))
+ self.n_multiprocessing_workers = int(params["multiprocessing_workers"])
+
+ # Relevant indices for TFT
+ self._input_obs_loc = json.loads(str(params["input_obs_loc"]))
+ self._static_input_loc = json.loads(str(params["static_input_loc"]))
+ self._known_regular_input_idx = json.loads(str(params["known_regular_inputs"]))
+ self._known_categorical_input_idx = json.loads(str(params["known_categorical_inputs"]))
+
+ self.column_definition = params["column_definition"]
+
+ # Network params
+ self.quantiles = [0.1, 0.5, 0.9]
+ self.use_cudnn = use_cudnn # Whether to use GPU optimised LSTM
+ self.hidden_layer_size = int(params["hidden_layer_size"])
+ self.dropout_rate = float(params["dropout_rate"])
+ self.max_gradient_norm = float(params["max_gradient_norm"])
+ self.learning_rate = float(params["learning_rate"])
+ self.minibatch_size = int(params["minibatch_size"])
+ self.num_epochs = int(params["num_epochs"])
+ self.early_stopping_patience = int(params["early_stopping_patience"])
+
+ self.num_encoder_steps = int(params["num_encoder_steps"])
+ self.num_stacks = int(params["stack_size"])
+ self.num_heads = int(params["num_heads"])
+
+ # Serialisation options
+ self._temp_folder = os.path.join(params["model_folder"], "tmp")
+ self.reset_temp_folder()
+
+ # Extra components to store Tensorflow nodes for attention computations
+ self._input_placeholder = None
+ self._attention_components = None
+ self._prediction_parts = None
+
+ print("*** {} params ***".format(self.name))
+ for k in params:
+ print("# {} = {}".format(k, params[k]))
+
+ # Build model
+ self.model = self.build_model()
+
+ def get_tft_embeddings(self, all_inputs):
+ """Transforms raw inputs to embeddings.
+
+ Applies linear transformation onto continuous variables and uses embeddings
+ for categorical variables.
+
+ Args:
+ all_inputs: Inputs to transform
+
+ Returns:
+ Tensors for transformed inputs.
+ """
+
+ time_steps = self.time_steps
+
+ # Sanity checks
+ for i in self._known_regular_input_idx:
+ if i in self._input_obs_loc:
+ raise ValueError("Observation cannot be known a priori!")
+ for i in self._input_obs_loc:
+ if i in self._static_input_loc:
+ raise ValueError("Observation cannot be static!")
+
+ if all_inputs.get_shape().as_list()[-1] != self.input_size:
+ raise ValueError(
+ "Illegal number of inputs! Inputs observed={}, expected={}".format(
+ all_inputs.get_shape().as_list()[-1], self.input_size
+ )
+ )
+
+ num_categorical_variables = len(self.category_counts)
+ num_regular_variables = self.input_size - num_categorical_variables
+
+ embedding_sizes = [self.hidden_layer_size for i, size in enumerate(self.category_counts)]
+
+ embeddings = []
+ for i in range(num_categorical_variables):
+
+ embedding = tf.keras.Sequential(
+ [
+ tf.keras.layers.InputLayer([time_steps]),
+ tf.keras.layers.Embedding(
+ self.category_counts[i], embedding_sizes[i], input_length=time_steps, dtype=tf.float32
+ ),
+ ]
+ )
+ embeddings.append(embedding)
+
+ regular_inputs, categorical_inputs = (
+ all_inputs[:, :, :num_regular_variables],
+ all_inputs[:, :, num_regular_variables:],
+ )
+
+ embedded_inputs = [embeddings[i](categorical_inputs[Ellipsis, i]) for i in range(num_categorical_variables)]
+
+ # Static inputs
+ if self._static_input_loc:
+ static_inputs = [
+ tf.keras.layers.Dense(self.hidden_layer_size)(regular_inputs[:, 0, i : i + 1])
+ for i in range(num_regular_variables)
+ if i in self._static_input_loc
+ ] + [
+ embedded_inputs[i][:, 0, :]
+ for i in range(num_categorical_variables)
+ if i + num_regular_variables in self._static_input_loc
+ ]
+ static_inputs = tf.keras.backend.stack(static_inputs, axis=1)
+
+ else:
+ static_inputs = None
+
+ def convert_real_to_embedding(x):
+ """Applies linear transformation for time-varying inputs."""
+ return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.hidden_layer_size))(x)
+
+ # Targets
+ obs_inputs = tf.keras.backend.stack(
+ [convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1]) for i in self._input_obs_loc], axis=-1
+ )
+
+ # Observed (a prioir unknown) inputs
+ wired_embeddings = []
+ for i in range(num_categorical_variables):
+ if i not in self._known_categorical_input_idx and i + num_regular_variables not in self._input_obs_loc:
+ e = embeddings[i](categorical_inputs[:, :, i])
+ wired_embeddings.append(e)
+
+ unknown_inputs = []
+ for i in range(regular_inputs.shape[-1]):
+ if i not in self._known_regular_input_idx and i not in self._input_obs_loc:
+ e = convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])
+ unknown_inputs.append(e)
+
+ if unknown_inputs + wired_embeddings:
+ unknown_inputs = tf.keras.backend.stack(unknown_inputs + wired_embeddings, axis=-1)
+ else:
+ unknown_inputs = None
+
+ # A priori known inputs
+ known_regular_inputs = [
+ convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1])
+ for i in self._known_regular_input_idx
+ if i not in self._static_input_loc
+ ]
+ known_categorical_inputs = [
+ embedded_inputs[i]
+ for i in self._known_categorical_input_idx
+ if i + num_regular_variables not in self._static_input_loc
+ ]
+
+ known_combined_layer = tf.keras.backend.stack(known_regular_inputs + known_categorical_inputs, axis=-1)
+
+ return unknown_inputs, known_combined_layer, obs_inputs, static_inputs
+
+ def _get_single_col_by_type(self, input_type):
+ """Returns name of single column for input type."""
+
+ return utils.get_single_col_by_input_type(input_type, self.column_definition)
+
+ def training_data_cached(self):
+ """Returns boolean indicating if training data has been cached."""
+
+ return TFTDataCache.contains("train") and TFTDataCache.contains("valid")
+
+ def cache_batched_data(self, data, cache_key, num_samples=-1):
+ """Batches and caches data once for using during training.
+
+ Args:
+ data: Data to batch and cache
+ cache_key: Key used for cache
+ num_samples: Maximum number of samples to extract (-1 to use all data)
+ """
+
+ if num_samples > 0:
+ TFTDataCache.update(self._batch_sampled_data(data, max_samples=num_samples), cache_key)
+ else:
+ TFTDataCache.update(self._batch_data(data), cache_key)
+
+ print('Cached data "{}" updated'.format(cache_key))
+
+ def _batch_sampled_data(self, data, max_samples):
+ """Samples segments into a compatible format.
+
+ Args:
+ data: Sources data to sample and batch
+ max_samples: Maximum number of samples in batch
+
+ Returns:
+ Dictionary of batched data with the maximum samples specified.
+ """
+
+ if max_samples < 1:
+ raise ValueError("Illegal number of samples specified! samples={}".format(max_samples))
+
+ id_col = self._get_single_col_by_type(InputTypes.ID)
+ time_col = self._get_single_col_by_type(InputTypes.TIME)
+
+ data.sort_values(by=[id_col, time_col], inplace=True)
+
+ print("Getting valid sampling locations.")
+ valid_sampling_locations = []
+ split_data_map = {}
+ for identifier, df in data.groupby(id_col):
+ print("Getting locations for {}".format(identifier))
+ num_entries = len(df)
+ if num_entries >= self.time_steps:
+ valid_sampling_locations += [
+ (identifier, self.time_steps + i) for i in range(num_entries - self.time_steps + 1)
+ ]
+ split_data_map[identifier] = df
+
+ inputs = np.zeros((max_samples, self.time_steps, self.input_size))
+ outputs = np.zeros((max_samples, self.time_steps, self.output_size))
+ time = np.empty((max_samples, self.time_steps, 1), dtype=object)
+ identifiers = np.empty((max_samples, self.time_steps, 1), dtype=object)
+
+ if max_samples > 0 and len(valid_sampling_locations) > max_samples:
+ print("Extracting {} samples...".format(max_samples))
+ ranges = [
+ valid_sampling_locations[i]
+ for i in np.random.choice(len(valid_sampling_locations), max_samples, replace=False)
+ ]
+ else:
+ print("Max samples={} exceeds # available segments={}".format(max_samples, len(valid_sampling_locations)))
+ ranges = valid_sampling_locations
+
+ id_col = self._get_single_col_by_type(InputTypes.ID)
+ time_col = self._get_single_col_by_type(InputTypes.TIME)
+ target_col = self._get_single_col_by_type(InputTypes.TARGET)
+ input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ for i, tup in enumerate(ranges):
+ if (i + 1 % 1000) == 0:
+ print(i + 1, "of", max_samples, "samples done...")
+ identifier, start_idx = tup
+ sliced = split_data_map[identifier].iloc[start_idx - self.time_steps : start_idx]
+ inputs[i, :, :] = sliced[input_cols]
+ outputs[i, :, :] = sliced[[target_col]]
+ time[i, :, 0] = sliced[time_col]
+ identifiers[i, :, 0] = sliced[id_col]
+
+ sampled_data = {
+ "inputs": inputs,
+ "outputs": outputs[:, self.num_encoder_steps :, :],
+ "active_entries": np.ones_like(outputs[:, self.num_encoder_steps :, :]),
+ "time": time,
+ "identifier": identifiers,
+ }
+
+ return sampled_data
+
+ def _batch_data(self, data):
+ """Batches data for training.
+
+ Converts raw dataframe from a 2-D tabular format to a batched 3-D array
+ to feed into Keras model.
+
+ Args:
+ data: DataFrame to batch
+
+ Returns:
+ Batched Numpy array with shape=(?, self.time_steps, self.input_size)
+ """
+
+ # Functions.
+ def _batch_single_entity(input_data):
+ time_steps = len(input_data)
+ lags = self.time_steps
+ x = input_data.values
+ if time_steps >= lags:
+ return np.stack([x[i : time_steps - (lags - 1) + i, :] for i in range(lags)], axis=1)
+
+ else:
+ return None
+
+ id_col = self._get_single_col_by_type(InputTypes.ID)
+ time_col = self._get_single_col_by_type(InputTypes.TIME)
+ target_col = self._get_single_col_by_type(InputTypes.TARGET)
+ input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
+
+ data_map = {}
+ for _, sliced in data.groupby(id_col):
+
+ col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols}
+
+ for k in col_mappings:
+ cols = col_mappings[k]
+ arr = _batch_single_entity(sliced[cols].copy())
+
+ if k not in data_map:
+ data_map[k] = [arr]
+ else:
+ data_map[k].append(arr)
+
+ # Combine all data
+ for k in data_map:
+ # Wendi: Avoid returning None when the length is not enough
+ data_map[k] = np.concatenate([i for i in data_map[k] if i is not None], axis=0)
+
+ # Shorten target so we only get decoder steps
+ data_map["outputs"] = data_map["outputs"][:, self.num_encoder_steps :, :]
+
+ active_entries = np.ones_like(data_map["outputs"])
+ if "active_entries" not in data_map:
+ data_map["active_entries"] = active_entries
+ else:
+ data_map["active_entries"].append(active_entries)
+
+ return data_map
+
+ def _get_active_locations(self, x):
+ """Formats sample weights for Keras training."""
+ return (np.sum(x, axis=-1) > 0.0) * 1.0
+
+ def _build_base_graph(self):
+ """Returns graph defining layers of the TFT."""
+
+ # Size definitions.
+ time_steps = self.time_steps
+ combined_input_size = self.input_size
+ encoder_steps = self.num_encoder_steps
+
+ # Inputs.
+ all_inputs = tf.keras.layers.Input(
+ shape=(
+ time_steps,
+ combined_input_size,
+ )
+ )
+
+ unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)
+
+ # Isolate known and observed historical inputs.
+ if unknown_inputs is not None:
+ historical_inputs = concat(
+ [
+ unknown_inputs[:, :encoder_steps, :],
+ known_combined_layer[:, :encoder_steps, :],
+ obs_inputs[:, :encoder_steps, :],
+ ],
+ axis=-1,
+ )
+ else:
+ historical_inputs = concat(
+ [known_combined_layer[:, :encoder_steps, :], obs_inputs[:, :encoder_steps, :]], axis=-1
+ )
+
+ # Isolate only known future inputs.
+ future_inputs = known_combined_layer[:, encoder_steps:, :]
+
+ def static_combine_and_mask(embedding):
+ """Applies variable selection network to static inputs.
+
+ Args:
+ embedding: Transformed static inputs
+
+ Returns:
+ Tensor output for variable selection network
+ """
+
+ # Add temporal features
+ _, num_static, _ = embedding.get_shape().as_list()
+
+ flatten = tf.keras.layers.Flatten()(embedding)
+
+ # Nonlinear transformation with gated residual network.
+ mlp_outputs = gated_residual_network(
+ flatten,
+ self.hidden_layer_size,
+ output_size=num_static,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=False,
+ additional_context=None,
+ )
+
+ sparse_weights = tf.keras.layers.Activation("softmax")(mlp_outputs)
+ sparse_weights = K.expand_dims(sparse_weights, axis=-1)
+
+ trans_emb_list = []
+ for i in range(num_static):
+ e = gated_residual_network(
+ embedding[:, i : i + 1, :],
+ self.hidden_layer_size,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=False,
+ )
+ trans_emb_list.append(e)
+
+ transformed_embedding = concat(trans_emb_list, axis=1)
+
+ combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])
+
+ static_vec = K.sum(combined, axis=1)
+
+ return static_vec, sparse_weights
+
+ static_encoder, static_weights = static_combine_and_mask(static_inputs)
+
+ static_context_variable_selection = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+ static_context_enrichment = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+ static_context_state_h = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+ static_context_state_c = gated_residual_network(
+ static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False
+ )
+
+ def lstm_combine_and_mask(embedding):
+ """Apply temporal variable selection networks.
+
+ Args:
+ embedding: Transformed inputs.
+
+ Returns:
+ Processed tensor outputs.
+ """
+
+ # Add temporal features
+ _, time_steps, embedding_dim, num_inputs = embedding.get_shape().as_list()
+
+ flatten = K.reshape(embedding, [-1, time_steps, embedding_dim * num_inputs])
+
+ expanded_static_context = K.expand_dims(static_context_variable_selection, axis=1)
+
+ # Variable selection weights
+ mlp_outputs, static_gate = gated_residual_network(
+ flatten,
+ self.hidden_layer_size,
+ output_size=num_inputs,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=True,
+ additional_context=expanded_static_context,
+ return_gate=True,
+ )
+
+ sparse_weights = tf.keras.layers.Activation("softmax")(mlp_outputs)
+ sparse_weights = tf.expand_dims(sparse_weights, axis=2)
+
+ # Non-linear Processing & weight application
+ trans_emb_list = []
+ for i in range(num_inputs):
+ grn_output = gated_residual_network(
+ embedding[Ellipsis, i],
+ self.hidden_layer_size,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=True,
+ )
+ trans_emb_list.append(grn_output)
+
+ transformed_embedding = stack(trans_emb_list, axis=-1)
+
+ combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding])
+ temporal_ctx = K.sum(combined, axis=-1)
+
+ return temporal_ctx, sparse_weights, static_gate
+
+ historical_features, historical_flags, _ = lstm_combine_and_mask(historical_inputs)
+ future_features, future_flags, _ = lstm_combine_and_mask(future_inputs)
+
+ # LSTM layer
+ def get_lstm(return_state):
+ """Returns LSTM cell initialized with default parameters."""
+ if self.use_cudnn:
+ lstm = tf.keras.layers.CuDNNLSTM(
+ self.hidden_layer_size,
+ return_sequences=True,
+ return_state=return_state,
+ stateful=False,
+ )
+ else:
+ lstm = tf.keras.layers.LSTM(
+ self.hidden_layer_size,
+ return_sequences=True,
+ return_state=return_state,
+ stateful=False,
+ # Additional params to ensure LSTM matches CuDNN, See TF 2.0 :
+ # (https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM)
+ activation="tanh",
+ recurrent_activation="sigmoid",
+ recurrent_dropout=0,
+ unroll=False,
+ use_bias=True,
+ )
+ return lstm
+
+ history_lstm, state_h, state_c = get_lstm(return_state=True)(
+ historical_features, initial_state=[static_context_state_h, static_context_state_c]
+ )
+
+ future_lstm = get_lstm(return_state=False)(future_features, initial_state=[state_h, state_c])
+
+ lstm_layer = concat([history_lstm, future_lstm], axis=1)
+
+ # Apply gated skip connection
+ input_embeddings = concat([historical_features, future_features], axis=1)
+
+ lstm_layer, _ = apply_gating_layer(lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None)
+ temporal_feature_layer = add_and_norm([lstm_layer, input_embeddings])
+
+ # Static enrichment layers
+ expanded_static_context = K.expand_dims(static_context_enrichment, axis=1)
+ enriched, _ = gated_residual_network(
+ temporal_feature_layer,
+ self.hidden_layer_size,
+ dropout_rate=self.dropout_rate,
+ use_time_distributed=True,
+ additional_context=expanded_static_context,
+ return_gate=True,
+ )
+
+ # Decoder self attention
+ self_attn_layer = InterpretableMultiHeadAttention(
+ self.num_heads, self.hidden_layer_size, dropout=self.dropout_rate
+ )
+
+ mask = get_decoder_mask(enriched)
+ x, self_att = self_attn_layer(enriched, enriched, enriched, mask=mask)
+
+ x, _ = apply_gating_layer(x, self.hidden_layer_size, dropout_rate=self.dropout_rate, activation=None)
+ x = add_and_norm([x, enriched])
+
+ # Nonlinear processing on outputs
+ decoder = gated_residual_network(
+ x, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=True
+ )
+
+ # Final skip connection
+ decoder, _ = apply_gating_layer(decoder, self.hidden_layer_size, activation=None)
+ transformer_layer = add_and_norm([decoder, temporal_feature_layer])
+
+ # Attention components for explainability
+ attention_components = {
+ # Temporal attention weights
+ "decoder_self_attn": self_att,
+ # Static variable selection weights
+ "static_flags": static_weights[Ellipsis, 0],
+ # Variable selection weights of past inputs
+ "historical_flags": historical_flags[Ellipsis, 0, :],
+ # Variable selection weights of future inputs
+ "future_flags": future_flags[Ellipsis, 0, :],
+ }
+
+ return transformer_layer, all_inputs, attention_components
+
+ def build_model(self):
+ """Build model and defines training losses.
+
+ Returns:
+ Fully defined Keras model.
+ """
+
+ with tf.variable_scope(self.name):
+
+ transformer_layer, all_inputs, attention_components = self._build_base_graph()
+
+ outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.output_size * len(self.quantiles)))(
+ transformer_layer[Ellipsis, self.num_encoder_steps :, :]
+ )
+
+ self._attention_components = attention_components
+
+ adam = tf.keras.optimizers.Adam(lr=self.learning_rate, clipnorm=self.max_gradient_norm)
+
+ model = tf.keras.Model(inputs=all_inputs, outputs=outputs)
+
+ print(model.summary())
+
+ valid_quantiles = self.quantiles
+ output_size = self.output_size
+
+ class QuantileLossCalculator(object):
+ """Computes the combined quantile loss for prespecified quantiles.
+
+ Attributes:
+ quantiles: Quantiles to compute losses
+ """
+
+ def __init__(self, quantiles):
+ """Initializes computer with quantiles for loss calculations.
+
+ Args:
+ quantiles: Quantiles to use for computations.
+ """
+ self.quantiles = quantiles
+
+ def quantile_loss(self, a, b):
+ """Returns quantile loss for specified quantiles.
+
+ Args:
+ a: Targets
+ b: Predictions
+ """
+ quantiles_used = set(self.quantiles)
+
+ loss = 0.0
+ for i, quantile in enumerate(valid_quantiles):
+ if quantile in quantiles_used:
+ loss += utils.tensorflow_quantile_loss(
+ a[Ellipsis, output_size * i : output_size * (i + 1)],
+ b[Ellipsis, output_size * i : output_size * (i + 1)],
+ quantile,
+ )
+ return loss
+
+ quantile_loss = QuantileLossCalculator(valid_quantiles).quantile_loss
+
+ model.compile(loss=quantile_loss, optimizer=adam, sample_weight_mode="temporal")
+
+ self._input_placeholder = all_inputs
+
+ return model
+
+ def fit(self, train_df=None, valid_df=None):
+ """Fits deep neural network for given training and validation data.
+
+ Args:
+ train_df: DataFrame for training data
+ valid_df: DataFrame for validation data
+ """
+
+ print("*** Fitting {} ***".format(self.name))
+
+ # Add relevant callbacks
+ callbacks = [
+ tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=self.early_stopping_patience, min_delta=1e-4),
+ tf.keras.callbacks.ModelCheckpoint(
+ filepath=self.get_keras_saved_path(self._temp_folder),
+ monitor="val_loss",
+ save_best_only=True,
+ save_weights_only=True,
+ ),
+ tf.keras.callbacks.TerminateOnNaN(),
+ ]
+
+ print("Getting batched_data")
+ if train_df is None:
+ print("Using cached training data")
+ train_data = TFTDataCache.get("train")
+ else:
+ train_data = self._batch_data(train_df)
+
+ if valid_df is None:
+ print("Using cached validation data")
+ valid_data = TFTDataCache.get("valid")
+ else:
+ valid_data = self._batch_data(valid_df)
+
+ print("Using keras standard fit")
+
+ def _unpack(data):
+ return data["inputs"], data["outputs"], self._get_active_locations(data["active_entries"])
+
+ # Unpack without sample weights
+ data, labels, active_flags = _unpack(train_data)
+ val_data, val_labels, val_flags = _unpack(valid_data)
+
+ all_callbacks = callbacks
+
+ self.model.fit(
+ x=data,
+ y=np.concatenate([labels, labels, labels], axis=-1),
+ sample_weight=active_flags,
+ epochs=self.num_epochs,
+ batch_size=self.minibatch_size,
+ validation_data=(val_data, np.concatenate([val_labels, val_labels, val_labels], axis=-1), val_flags),
+ callbacks=all_callbacks,
+ shuffle=True,
+ use_multiprocessing=True,
+ workers=self.n_multiprocessing_workers,
+ )
+
+ # Load best checkpoint again
+ tmp_checkpont = self.get_keras_saved_path(self._temp_folder)
+ if os.path.exists(tmp_checkpont):
+ self.load(self._temp_folder, use_keras_loadings=True)
+
+ else:
+ print("Cannot load from {}, skipping ...".format(self._temp_folder))
+
+ def evaluate(self, data=None, eval_metric="loss"):
+ """Applies evaluation metric to the training data.
+
+ Args:
+ data: Dataframe for evaluation
+ eval_metric: Evaluation metic to return, based on model definition.
+
+ Returns:
+ Computed evaluation loss.
+ """
+
+ if data is None:
+ print("Using cached validation data")
+ raw_data = TFTDataCache.get("valid")
+ else:
+ raw_data = self._batch_data(data)
+
+ inputs = raw_data["inputs"]
+ outputs = raw_data["outputs"]
+ active_entries = self._get_active_locations(raw_data["active_entries"])
+
+ metric_values = self.model.evaluate(
+ x=inputs,
+ y=np.concatenate([outputs, outputs, outputs], axis=-1),
+ sample_weight=active_entries,
+ workers=16,
+ use_multiprocessing=True,
+ )
+
+ metrics = pd.Series(metric_values, self.model.metrics_names)
+
+ return metrics[eval_metric]
+
+ def predict(self, df, return_targets=False):
+ """Computes predictions for a given input dataset.
+
+ Args:
+ df: Input dataframe
+ return_targets: Whether to also return outputs aligned with predictions to
+ faciliate evaluation
+
+ Returns:
+ Input dataframe or tuple of (input dataframe, algined output dataframe).
+ """
+
+ data = self._batch_data(df)
+
+ inputs = data["inputs"]
+ time = data["time"]
+ identifier = data["identifier"]
+ outputs = data["outputs"]
+
+ combined = self.model.predict(inputs, workers=16, use_multiprocessing=True, batch_size=self.minibatch_size)
+
+ # Format output_csv
+ if self.output_size != 1:
+ raise NotImplementedError("Current version only supports 1D targets!")
+
+ def format_outputs(prediction):
+ """Returns formatted dataframes for prediction."""
+
+ flat_prediction = pd.DataFrame(
+ prediction[:, :, 0], columns=["t+{}".format(i) for i in range(self.time_steps - self.num_encoder_steps)]
+ )
+ cols = list(flat_prediction.columns)
+ flat_prediction["forecast_time"] = time[:, self.num_encoder_steps - 1, 0]
+ flat_prediction["identifier"] = identifier[:, 0, 0]
+
+ # Arrange in order
+ return flat_prediction[["forecast_time", "identifier"] + cols]
+
+ # Extract predictions for each quantile into different entries
+ process_map = {
+ "p{}".format(int(q * 100)): combined[Ellipsis, i * self.output_size : (i + 1) * self.output_size]
+ for i, q in enumerate(self.quantiles)
+ }
+
+ if return_targets:
+ # Add targets if relevant
+ process_map["targets"] = outputs
+
+ return {k: format_outputs(process_map[k]) for k in process_map}
+
+ def get_attention(self, df):
+ """Computes TFT attention weights for a given dataset.
+
+ Args:
+ df: Input dataframe
+
+ Returns:
+ Dictionary of numpy arrays for temporal attention weights and variable
+ selection weights, along with their identifiers and time indices
+ """
+
+ data = self._batch_data(df)
+ inputs = data["inputs"]
+ identifiers = data["identifier"]
+ time = data["time"]
+
+ def get_batch_attention_weights(input_batch):
+ """Returns weights for a given minibatch of data."""
+ input_placeholder = self._input_placeholder
+ attention_weights = {}
+ for k in self._attention_components:
+ attention_weight = tf.keras.backend.get_session().run(
+ self._attention_components[k], {input_placeholder: input_batch.astype(np.float32)}
+ )
+ attention_weights[k] = attention_weight
+ return attention_weights
+
+ # Compute number of batches
+ batch_size = self.minibatch_size
+ n = inputs.shape[0]
+ num_batches = n // batch_size
+ if n - (num_batches * batch_size) > 0:
+ num_batches += 1
+
+ # Split up inputs into batches
+ batched_inputs = [inputs[i * batch_size : (i + 1) * batch_size, Ellipsis] for i in range(num_batches)]
+
+ # Get attention weights, while avoiding large memory increases
+ attention_by_batch = [get_batch_attention_weights(batch) for batch in batched_inputs]
+ attention_weights = {}
+ for k in self._attention_components:
+ attention_weights[k] = []
+ for batch_weights in attention_by_batch:
+ attention_weights[k].append(batch_weights[k])
+
+ if len(attention_weights[k][0].shape) == 4:
+ tmp = np.concatenate(attention_weights[k], axis=1)
+ else:
+ tmp = np.concatenate(attention_weights[k], axis=0)
+
+ del attention_weights[k]
+ gc.collect()
+ attention_weights[k] = tmp
+
+ attention_weights["identifiers"] = identifiers[:, 0, 0]
+ attention_weights["time"] = time[:, :, 0]
+
+ return attention_weights
+
+ # Serialisation.
+ def reset_temp_folder(self):
+ """Deletes and recreates folder with temporary Keras training outputs."""
+ print("Resetting temp folder...")
+ utils.create_folder_if_not_exist(self._temp_folder)
+ shutil.rmtree(self._temp_folder)
+ os.makedirs(self._temp_folder)
+
+ def get_keras_saved_path(self, model_folder):
+ """Returns path to keras checkpoint."""
+ return os.path.join(model_folder, "{}.check".format(self.name))
+
+ def save(self, model_folder):
+ """Saves optimal TFT weights.
+
+ Args:
+ model_folder: Location to serialze model.
+ """
+ # Allows for direct serialisation of tensorflow variables to avoid spurious
+ # issue with Keras that leads to different performance evaluation results
+ # when model is reloaded (https://github.com/keras-team/keras/issues/4875).
+
+ utils.save(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name)
+
+ def load(self, model_folder, use_keras_loadings=False):
+ """Loads TFT weights.
+
+ Args:
+ model_folder: Folder containing serialized models.
+ use_keras_loadings: Whether to load from Keras checkpoint.
+
+ Returns:
+
+ """
+ if use_keras_loadings:
+ # Loads temporary Keras model saved during training.
+ serialisation_path = self.get_keras_saved_path(model_folder)
+ print("Loading model from {}".format(serialisation_path))
+ self.model.load_weights(serialisation_path)
+ else:
+ # Loads tensorflow graph for optimal models.
+ utils.load(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name)
+
+ @classmethod
+ def get_hyperparm_choices(cls):
+ """Returns hyperparameter ranges for random search."""
+ return {
+ "dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9],
+ "hidden_layer_size": [10, 20, 40, 80, 160, 240, 320],
+ "minibatch_size": [64, 128, 256],
+ "learning_rate": [1e-4, 1e-3, 1e-2],
+ "max_gradient_norm": [0.01, 1.0, 100.0],
+ "num_heads": [1, 4],
+ "stack_size": [1],
+ }
diff --git a/examples/benchmarks/TFT/libs/utils.py b/examples/benchmarks/TFT/libs/utils.py
new file mode 100644
index 000000000..4682434d6
--- /dev/null
+++ b/examples/benchmarks/TFT/libs/utils.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2020 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Generic helper functions used across codebase."""
+
+import os
+import pathlib
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
+
+
+# Generic.
+def get_single_col_by_input_type(input_type, column_definition):
+ """Returns name of single column.
+
+ Args:
+ input_type: Input type of column to extract
+ column_definition: Column definition list for experiment
+ """
+
+ l = [tup[0] for tup in column_definition if tup[2] == input_type]
+
+ if len(l) != 1:
+ raise ValueError("Invalid number of columns for {}".format(input_type))
+
+ return l[0]
+
+
+def extract_cols_from_data_type(data_type, column_definition, excluded_input_types):
+ """Extracts the names of columns that correspond to a define data_type.
+
+ Args:
+ data_type: DataType of columns to extract.
+ column_definition: Column definition to use.
+ excluded_input_types: Set of input types to exclude
+
+ Returns:
+ List of names for columns with data type specified.
+ """
+ return [tup[0] for tup in column_definition if tup[1] == data_type and tup[2] not in excluded_input_types]
+
+
+# Loss functions.
+def tensorflow_quantile_loss(y, y_pred, quantile):
+ """Computes quantile loss for tensorflow.
+
+ Standard quantile loss as defined in the "Training Procedure" section of
+ the main TFT paper
+
+ Args:
+ y: Targets
+ y_pred: Predictions
+ quantile: Quantile to use for loss calculations (between 0 & 1)
+
+ Returns:
+ Tensor for quantile loss.
+ """
+
+ # Checks quantile
+ if quantile < 0 or quantile > 1:
+ raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile))
+
+ prediction_underflow = y - y_pred
+ q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * tf.maximum(
+ -prediction_underflow, 0.0
+ )
+
+ return tf.reduce_sum(q_loss, axis=-1)
+
+
+def numpy_normalised_quantile_loss(y, y_pred, quantile):
+ """Computes normalised quantile loss for numpy arrays.
+
+ Uses the q-Risk metric as defined in the "Training Procedure" section of the
+ main TFT paper.
+
+ Args:
+ y: Targets
+ y_pred: Predictions
+ quantile: Quantile to use for loss calculations (between 0 & 1)
+
+ Returns:
+ Float for normalised quantile loss.
+ """
+ prediction_underflow = y - y_pred
+ weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum(
+ -prediction_underflow, 0.0
+ )
+
+ quantile_loss = weighted_errors.mean()
+ normaliser = y.abs().mean()
+
+ return 2 * quantile_loss / normaliser
+
+
+# OS related functions.
+def create_folder_if_not_exist(directory):
+ """Creates folder if it doesn't exist.
+
+ Args:
+ directory: Folder path to create.
+ """
+ # Also creates directories recursively
+ pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
+
+
+# Tensorflow related functions.
+def get_default_tensorflow_config(tf_device="gpu", gpu_id=0):
+ """Creates tensorflow config for graphs to run on CPU or GPU.
+
+ Specifies whether to run graph on gpu or cpu and which GPU ID to use for multi
+ GPU machines.
+
+ Args:
+ tf_device: 'cpu' or 'gpu'
+ gpu_id: GPU ID to use if relevant
+
+ Returns:
+ Tensorflow config.
+ """
+
+ if tf_device == "cpu":
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # for training on cpu
+ tf_config = tf.ConfigProto(log_device_placement=False, device_count={"GPU": 0})
+
+ else:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
+ print("Selecting GPU ID={}".format(gpu_id))
+
+ tf_config = tf.ConfigProto(log_device_placement=False)
+ tf_config.gpu_options.allow_growth = True
+
+ return tf_config
+
+
+def save(tf_session, model_folder, cp_name, scope=None):
+ """Saves Tensorflow graph to checkpoint.
+
+ Saves all trainiable variables under a given variable scope to checkpoint.
+
+ Args:
+ tf_session: Session containing graph
+ model_folder: Folder to save models
+ cp_name: Name of Tensorflow checkpoint
+ scope: Variable scope containing variables to save
+ """
+ # Save model
+ if scope is None:
+ saver = tf.train.Saver()
+ else:
+ var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
+ saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
+
+ save_path = saver.save(tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name)))
+ print("Model saved to: {0}".format(save_path))
+
+
+def load(tf_session, model_folder, cp_name, scope=None, verbose=False):
+ """Loads Tensorflow graph from checkpoint.
+
+ Args:
+ tf_session: Session to load graph into
+ model_folder: Folder containing serialised model
+ cp_name: Name of Tensorflow checkpoint
+ scope: Variable scope to use.
+ verbose: Whether to print additional debugging information.
+ """
+ # Load model proper
+ load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
+
+ print("Loading model from {0}".format(load_path))
+
+ print_weights_in_checkpoint(model_folder, cp_name)
+
+ initial_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
+
+ # Saver
+ if scope is None:
+ saver = tf.train.Saver()
+ else:
+ var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
+ saver = tf.train.Saver(var_list=var_list, max_to_keep=100000)
+ # Load
+ saver.restore(tf_session, load_path)
+ all_vars = set([v.name for v in tf.get_default_graph().as_graph_def().node])
+
+ if verbose:
+ print("Restored {0}".format(",".join(initial_vars.difference(all_vars))))
+ print("Existing {0}".format(",".join(all_vars.difference(initial_vars))))
+ print("All {0}".format(",".join(all_vars)))
+
+ print("Done.")
+
+
+def print_weights_in_checkpoint(model_folder, cp_name):
+ """Prints all weights in Tensorflow checkpoint.
+
+ Args:
+ model_folder: Folder containing checkpoint
+ cp_name: Name of checkpoint
+
+ Returns:
+
+ """
+ load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name))
+
+ print_tensors_in_checkpoint_file(file_name=load_path, tensor_name="", all_tensors=True, all_tensor_names=True)
diff --git a/examples/benchmarks/TFT/requirements.txt b/examples/benchmarks/TFT/requirements.txt
new file mode 100644
index 000000000..04234aaed
--- /dev/null
+++ b/examples/benchmarks/TFT/requirements.txt
@@ -0,0 +1,3 @@
+tensorflow-gpu==1.15.0
+numpy == 1.19.4
+pandas==1.1.0
\ No newline at end of file
diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py
new file mode 100644
index 000000000..a3b4fc919
--- /dev/null
+++ b/examples/benchmarks/TFT/tft.py
@@ -0,0 +1,248 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import numpy as np
+import pandas as pd
+import tensorflow.compat.v1 as tf
+import data_formatters.base
+import expt_settings.configs
+import libs.hyperparam_opt
+import libs.tft_model
+import libs.utils as utils
+import os
+import datetime as dte
+
+
+from qlib.model.base import ModelFT
+from qlib.data.dataset import DatasetH
+from qlib.data.dataset.handler import DataHandlerLP
+
+
+# To register new datasets, please add them here.
+ALLOW_DATASET = ["Alpha158"]
+DATASET_SETTING = {
+ "Alpha158": {
+ "feature_col": ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "ROC60", "RESI10"],
+ "label_col": ["LABEL0"],
+ },
+}
+# To register new datasets, please add their configurations here.
+
+
+def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
+ return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
+
+
+def fill_test_na(test_df):
+ test_df_res = test_df.copy()
+ feature_cols = ~test_df_res.columns.str.contains("label", case=False)
+ test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
+ test_df_res.loc[:, feature_cols] = test_feature_fna
+ return test_df_res
+
+
+def process_qlib_data(df, dataset, fillna=False):
+ """Prepare data to fit the TFT model.
+
+ Args:
+ df: Original DataFrame.
+ fillna: Whether to fill the data with the mean values.
+
+ Returns:
+ Transformed DataFrame.
+
+ """
+ # Several features selected manually
+ feature_col = DATASET_SETTING[dataset]["feature_col"]
+ label_col = DATASET_SETTING[dataset]["label_col"]
+ temp_df = df.loc[:, feature_col + label_col]
+ if fillna:
+ temp_df = fill_test_na(temp_df)
+ temp_df = temp_df.swaplevel()
+ temp_df = temp_df.sort_index()
+ temp_df = temp_df.reset_index(level=0)
+ dates = pd.to_datetime(temp_df.index)
+ temp_df["date"] = dates
+ temp_df["day_of_week"] = dates.dayofweek
+ temp_df["month"] = dates.month
+ temp_df["year"] = dates.year
+ temp_df["const"] = 1.0
+ return temp_df
+
+
+def process_predicted(df, col_name):
+ """Transform the TFT predicted data into Qlib format.
+
+ Args:
+ df: Original DataFrame.
+ fillna: New column name.
+
+ Returns:
+ Transformed DataFrame.
+
+ """
+ df_res = df.copy()
+ df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+5": col_name})
+ df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
+ df_res = df_res[[col_name]]
+ return df_res
+
+
+def format_score(forecast_df, col_name="pred", label_shift=5):
+ pred = process_predicted(forecast_df, col_name=col_name)
+ pred = get_shifted_label(pred, shifts=-label_shift, col_shift=col_name)
+ pred = pred.dropna()[col_name]
+ return pred
+
+
+def transform_df(df, col_name="LABEL0"):
+ df_res = df["feature"]
+ df_res[col_name] = df["label"]
+ return df_res
+
+
+class TFTModel(ModelFT):
+ """TFT Model"""
+
+ def __init__(self, **kwargs):
+ self.model = None
+
+ def _prepare_data(self, dataset: DatasetH):
+ df_train, df_valid = dataset.prepare(
+ ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+ return transform_df(df_train), transform_df(df_valid)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ DATASET="Alpha158",
+ MODEL_FOLDER="qlib_alpha158_model",
+ LABEL_COL="LABEL0",
+ LABEL_SHIFT=5,
+ USE_GPU_ID=0,
+ **kwargs
+ ):
+
+ if DATASET not in ALLOW_DATASET:
+ raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset")
+
+ dtrain, dvalid = self._prepare_data(dataset)
+ dtrain.loc[:, LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
+ dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL)
+
+ train = process_qlib_data(dtrain, DATASET, fillna=True).dropna()
+ valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna()
+
+ ExperimentConfig = expt_settings.configs.ExperimentConfig
+ config = ExperimentConfig(DATASET)
+ self.data_formatter = config.make_data_formatter()
+ self.model_folder = MODEL_FOLDER
+ self.gpu_id = USE_GPU_ID
+ self.label_shift = LABEL_SHIFT
+ self.expt_name = DATASET
+ self.label_col = LABEL_COL
+
+ use_gpu = (True, self.gpu_id)
+ # ===========================Training Process===========================
+ ModelClass = libs.tft_model.TemporalFusionTransformer
+ if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter):
+ raise ValueError(
+ "Data formatters should inherit from"
+ + "AbstractDataFormatter! Type={}".format(type(self.data_formatter))
+ )
+
+ default_keras_session = tf.keras.backend.get_session()
+
+ if use_gpu[0]:
+ self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1])
+ else:
+ self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu")
+
+ self.data_formatter.set_scalers(train)
+
+ # Sets up default params
+ fixed_params = self.data_formatter.get_experiment_params()
+ params = self.data_formatter.get_default_model_params()
+
+ # Wendi: 合并调优的参数和非调优的参数
+ params = {**params, **fixed_params}
+
+ if not os.path.exists(self.model_folder):
+ os.makedirs(self.model_folder)
+ params["model_folder"] = self.model_folder
+
+ print("*** Begin training ***")
+ best_loss = np.Inf
+
+ tf.reset_default_graph()
+
+ self.tf_graph = tf.Graph()
+ with self.tf_graph.as_default():
+ self.sess = tf.Session(config=self.tf_config)
+ tf.keras.backend.set_session(self.sess)
+ self.model = ModelClass(params, use_cudnn=use_gpu[0])
+ self.sess.run(tf.global_variables_initializer())
+ self.model.fit(train_df=train, valid_df=valid)
+ print("*** Finished training ***")
+ saved_model_dir = self.model_folder + "/" + "saved_model"
+ if not os.path.exists(saved_model_dir):
+ os.makedirs(saved_model_dir)
+ self.model.save(saved_model_dir)
+
+ def extract_numerical_data(data):
+ """Strips out forecast time and identifier columns."""
+ return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]]
+
+ # p50_loss = utils.numpy_normalised_quantile_loss(
+ # extract_numerical_data(targets), extract_numerical_data(p50_forecast),
+ # 0.5)
+ # p90_loss = utils.numpy_normalised_quantile_loss(
+ # extract_numerical_data(targets), extract_numerical_data(p90_forecast),
+ # 0.9)
+ tf.keras.backend.set_session(default_keras_session)
+ print("Training completed.".format(dte.datetime.now()))
+ # ===========================Training Process===========================
+
+ def predict(self, dataset):
+ if self.model is None:
+ raise ValueError("model is not fitted yet!")
+ d_test = dataset.prepare("test", col_set=["feature", "label"])
+ d_test = transform_df(d_test)
+ d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col)
+ test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna()
+
+ use_gpu = (True, self.gpu_id)
+ # ===========================Predicting Process===========================
+ default_keras_session = tf.keras.backend.get_session()
+
+ # Sets up default params
+ fixed_params = self.data_formatter.get_experiment_params()
+ params = self.data_formatter.get_default_model_params()
+ params = {**params, **fixed_params}
+
+ print("*** Begin predicting ***")
+ tf.reset_default_graph()
+
+ with self.tf_graph.as_default():
+ tf.keras.backend.set_session(self.sess)
+ output_map = self.model.predict(test, return_targets=True)
+ targets = self.data_formatter.format_predictions(output_map["targets"])
+ p50_forecast = self.data_formatter.format_predictions(output_map["p50"])
+ p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
+ tf.keras.backend.set_session(default_keras_session)
+
+ predict = format_score(p90_forecast, "pred", 0) # self.label_shift
+ label = format_score(targets, "label", 0)
+ # ===========================Predicting Process===========================
+ return predict, label
+
+ def finetune(self, dataset: DatasetH):
+ """
+ finetune model
+ Parameters
+ ----------
+ dataset : DatasetH
+ dataset for finetuning
+ """
+ pass
diff --git a/examples/benchmarks/TFT/workflow_config_tft.yaml b/examples/benchmarks/TFT/workflow_config_tft.yaml
new file mode 100644
index 000000000..d8ee14e71
--- /dev/null
+++ b/examples/benchmarks/TFT/workflow_config_tft.yaml
@@ -0,0 +1,52 @@
+sys:
+ rel_path: .
+provider_uri: "~/.qlib/qlib_data/cn_data"
+region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: TFTModel
+ module_path: tft
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha158
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
diff --git a/examples/benchmarks/TabNet/workflow_config_tabnet.yaml b/examples/benchmarks/TabNet/workflow_config_tabnet.yaml
index 0ee95f238..5f6aa8b6d 100644
--- a/examples/benchmarks/TabNet/workflow_config_tabnet.yaml
+++ b/examples/benchmarks/TabNet/workflow_config_tabnet.yaml
@@ -44,7 +44,7 @@ task:
module_path: qlib.data.dataset
kwargs:
handler:
- class: Alpha158
+ class: ALPHA360_Denoise
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
index 407d56fb7..31eee8206 100644
--- a/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
+++ b/examples/benchmarks/XGBoost/workflow_config_xgboost.yaml
@@ -29,18 +29,15 @@ task:
class: XGBModel
module_path: qlib.contrib.model.xgboost
kwargs:
- objective: reg:linear
- n_estimators: 5000
- colsample_bytree: 0.85
- learning_rate: 0.0421
- subsample: 0.8789
- max_depth: 8
- num_leaves: 210
- num_threads: 20
- missing: -1
- min_child_weight: 1
+ eval_metric: rmse
+ colsample_bytree: 0.5
+ eta: 0.2
+ gamma: 0.55
+ max_depth: 2
+ min_child_weight: 1.0
+ n_estimators: 647
+ subsample: 0.8
nthread: 4
- tree_method: hist
dataset:
class: DatasetH
module_path: qlib.data.dataset
diff --git a/examples/portfolio_optimization_example.ipynb b/examples/portfolio_optimization_example.ipynb
new file mode 100644
index 000000000..4d6c2b3d2
--- /dev/null
+++ b/examples/portfolio_optimization_example.ipynb
@@ -0,0 +1,446 @@
+{
+ "metadata": {
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.9-final"
+ },
+ "orig_nbformat": 2,
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import copy\n",
+ "from pathlib import Path\n",
+ "\n",
+ "import qlib\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from qlib.config import REG_CN\n",
+ "from qlib.contrib.model.gbdt import LGBModel\n",
+ "from qlib.contrib.data.handler import Alpha158\n",
+ "from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n",
+ "from qlib.contrib.evaluate import (\n",
+ " backtest as normal_backtest,\n",
+ " risk_analysis,\n",
+ ")\n",
+ "from qlib.utils import exists_qlib_data, init_instance_by_config\n",
+ "from qlib.workflow import R\n",
+ "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
+ "from qlib.utils import flatten_dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "[35366:MainThread](2020-11-27 10:31:09,528) INFO - qlib.Initialization - [__init__.py:41] - default_conf: client.\n",
+ "[35366:MainThread](2020-11-27 10:31:09,531) WARNING - qlib.Initialization - [__init__.py:57] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
+ "[35366:MainThread](2020-11-27 10:31:09,531) INFO - qlib.Initialization - [__init__.py:76] - qlib successfully initialized based on client settings.\n",
+ "[35366:MainThread](2020-11-27 10:31:09,532) INFO - qlib.Initialization - [__init__.py:79] - data_path=/home/dongzho/.qlib/qlib_data/cn_data\n"
+ ]
+ }
+ ],
+ "source": [
+ "# use default data\n",
+ "# NOTE: need to download data from remote: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
+ "provider_uri = \"~/.qlib/qlib_data/cn_data\" # target_dir\n",
+ "if not exists_qlib_data(provider_uri):\n",
+ " print(f\"Qlib data is not found in {provider_uri}\")\n",
+ " sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
+ " from get_data import GetData\n",
+ " GetData().qlib_data(target_dir=provider_uri, region=REG_CN)\n",
+ "qlib.init(provider_uri=provider_uri, region=REG_CN)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "market = \"csi300\"\n",
+ "benchmark = \"SH000300\""
+ ]
+ },
+ {
+ "source": [
+ "## Model Training"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "[35366:MainThread](2020-11-27 10:31:29,731) INFO - qlib.timer - [log.py:81] - Time cost: 20.103s | Loading data Done\n",
+ "[35366:MainThread](2020-11-27 10:31:30,557) INFO - qlib.timer - [log.py:81] - Time cost: 0.241s | DropnaLabel Done\n",
+ "[35366:MainThread](2020-11-27 10:31:38,518) INFO - qlib.timer - [log.py:81] - Time cost: 7.960s | CSZScoreNorm Done\n",
+ "[35366:MainThread](2020-11-27 10:31:38,519) INFO - qlib.timer - [log.py:81] - Time cost: 8.786s | fit & process data Done\n",
+ "[35366:MainThread](2020-11-27 10:31:38,520) INFO - qlib.timer - [log.py:81] - Time cost: 28.891s | Init data Done\n",
+ "[35366:MainThread](2020-11-27 10:31:38,527) INFO - qlib.workflow - [exp.py:180] - Experiment 2 starts running ...\n",
+ "[35366:MainThread](2020-11-27 10:31:38,651) INFO - qlib.workflow - [recorder.py:234] - Recorder c81375e3b5474feb9c77711babd158c3 starts running under Experiment 2 ...\n",
+ "[35366:MainThread](2020-11-27 10:31:38,652) INFO - qlib.workflow - [expm.py:251] - No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.\n",
+ "Training until validation scores don't improve for 50 rounds\n",
+ "[20]\ttrain's l2: 0.990559\tvalid's l2: 0.994332\n",
+ "[40]\ttrain's l2: 0.98687\tvalid's l2: 0.993702\n",
+ "[60]\ttrain's l2: 0.984308\tvalid's l2: 0.993503\n",
+ "[80]\ttrain's l2: 0.982202\tvalid's l2: 0.993446\n",
+ "[100]\ttrain's l2: 0.980318\tvalid's l2: 0.993423\n",
+ "[120]\ttrain's l2: 0.97854\tvalid's l2: 0.993409\n",
+ "[140]\ttrain's l2: 0.97679\tvalid's l2: 0.993413\n",
+ "[160]\ttrain's l2: 0.975116\tvalid's l2: 0.993473\n",
+ "Early stopping, best iteration is:\n",
+ "[127]\ttrain's l2: 0.977957\tvalid's l2: 0.993381\n"
+ ]
+ }
+ ],
+ "source": [
+ "###################################\n",
+ "# train model\n",
+ "###################################\n",
+ "data_handler_config = {\n",
+ " \"start_time\": \"2008-01-01\",\n",
+ " \"end_time\": \"2020-08-01\",\n",
+ " \"fit_start_time\": \"2008-01-01\",\n",
+ " \"fit_end_time\": \"2014-12-31\",\n",
+ " \"instruments\": market,\n",
+ "}\n",
+ "\n",
+ "task = {\n",
+ " \"model\": {\n",
+ " \"class\": \"LGBModel\",\n",
+ " \"module_path\": \"qlib.contrib.model.gbdt\",\n",
+ " \"kwargs\": {\n",
+ " \"loss\": \"mse\",\n",
+ " \"colsample_bytree\": 0.8879,\n",
+ " \"learning_rate\": 0.0421,\n",
+ " \"subsample\": 0.8789,\n",
+ " \"lambda_l1\": 205.6999,\n",
+ " \"lambda_l2\": 580.9768,\n",
+ " \"max_depth\": 8,\n",
+ " \"num_leaves\": 210,\n",
+ " \"num_threads\": 20,\n",
+ " },\n",
+ " },\n",
+ " \"dataset\": {\n",
+ " \"class\": \"DatasetH\",\n",
+ " \"module_path\": \"qlib.data.dataset\",\n",
+ " \"kwargs\": {\n",
+ " \"handler\": {\n",
+ " \"class\": \"Alpha158\",\n",
+ " \"module_path\": \"qlib.contrib.data.handler\",\n",
+ " \"kwargs\": data_handler_config,\n",
+ " },\n",
+ " \"segments\": {\n",
+ " \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
+ " \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
+ " \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
+ " },\n",
+ " },\n",
+ " },\n",
+ "}\n",
+ "\n",
+ "# model initiaiton\n",
+ "model = init_instance_by_config(task[\"model\"])\n",
+ "dataset = init_instance_by_config(task[\"dataset\"])\n",
+ "\n",
+ "# start exp to train model\n",
+ "with R.start(experiment_name=\"train_model\"):\n",
+ " R.log_params(**flatten_dict(task))\n",
+ " model.fit(dataset)\n",
+ " R.save_objects(trained_model=model)\n",
+ " rid = R.get_recorder().id\n"
+ ]
+ },
+ {
+ "source": [
+ "## Optimization Based Strategy"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from qlib.contrib.strategy.strategy import BaseStrategy\n",
+ "\n",
+ "\n",
+ "class OptBasedStrategy(BaseStrategy):\n",
+ " \"\"\"Optimization Based Strategy\"\"\"\n",
+ "\n",
+ " def __init__(self, data_handler, cov_estimator, optimizer):\n",
+ " self.data_handler = data_handler\n",
+ " self.cov_estimator = cov_estimator\n",
+ " self.optimizer = optimizer\n",
+ "\n",
+ " def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):\n",
+ " \"\"\"\n",
+ " Parameters\n",
+ " -----------\n",
+ " score_series : pd.Seires\n",
+ " stock_id , score.\n",
+ " current : Position()\n",
+ " current of account.\n",
+ " trade_exchange : Exchange()\n",
+ " exchange.\n",
+ " trade_date : pd.Timestamp\n",
+ " date.\n",
+ " \"\"\"\n",
+ " score_series = score_series.dropna()\n",
+ "\n",
+ " # check stock holdings, if\n",
+ " # 1. doesn't have score: target amount = 0 (force sell)\n",
+ " # 2. stock not tradable: target amount = current amount\n",
+ " current_position = current.get_stock_amount_dict()\n",
+ " target_position = {}\n",
+ " for stock_id in current_position:\n",
+ " if not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
+ " target_position[stock_id] = current_position[stock_id]\n",
+ " elif stock_id not in score_series.index:\n",
+ " target_position[stock_id] = 0\n",
+ " else:\n",
+ " # need to be solved by optimizer\n",
+ " pass\n",
+ "\n",
+ " # filter scores, if\n",
+ " # 1. kept in `amount_dict` by previous rules\n",
+ " # 2. not tradable\n",
+ " skipped = []\n",
+ " for stock_id in score_series.index:\n",
+ " if stock_id in target_position:\n",
+ " skipped.append(stock_id)\n",
+ " elif not trade_exchange.is_stock_tradable(stock_id=stock_id, trade_date=trade_date):\n",
+ " skipped.append(stock_id)\n",
+ " score_series = score_series[~score_series.index.isin(skipped)]\n",
+ "\n",
+ " # calc remaining value\n",
+ " current_value = pd.Series({\n",
+ " stock_id: current.get_stock_price(stock_id) * amount\n",
+ " for stock_id, amount in current_position.items()\n",
+ " })\n",
+ " risk_total_value = self.get_risk_degree(trade_date) * current.calculate_value()\n",
+ " traded_value = risk_total_value - current_value.loc[list(target_position)].sum()\n",
+ "\n",
+ " # portfolio init weight\n",
+ " init_weight = current_value.reindex(score_series.index, fill_value=0)\n",
+ " init_weight_sum = init_weight.sum()\n",
+ " if init_weight_sum > 0:\n",
+ " init_weight /= init_weight_sum\n",
+ "\n",
+ " # covariance estimation\n",
+ " selector = (self.data_handler.get_range_selector(pred_date, 252), score_series.index)\n",
+ " price = self.data_handler.fetch(selector, level=None, squeeze=True)\n",
+ " cov = self.cov_estimator(price)\n",
+ " cov = cov.reindex(\n",
+ " index=score_series.index, \n",
+ " columns=score_series.index, \n",
+ " #fill_value=cov.max().max()\n",
+ " )\n",
+ "\n",
+ " # optimize target portfolio\n",
+ " if init_weight.sum() > 0:\n",
+ " target_weight = self.optimizer(cov, score_series, init_weight)\n",
+ " else:\n",
+ " target_weight = self.optimizer(cov, score_series)\n",
+ " target_weight = target_weight[target_weight > 1e-6]\n",
+ " for stock_id, weight in target_weight.items():\n",
+ " try:\n",
+ " target_position[stock_id] = int(traded_value * weight / trade_exchange.get_close(stock_id, pred_date))\n",
+ " except Exception as e:\n",
+ " # TODO: unknown exception\n",
+ " print('Exception:', e)\n",
+ "\n",
+ " # for debug\n",
+ " print('trade date:', trade_date)\n",
+ " print('target weight:', target_weight.to_dict())\n",
+ " print('target position:', target_position)\n",
+ "\n",
+ " # generate order list\n",
+ " order_list = trade_exchange.generate_order_for_target_amount_position(\n",
+ " target_position=target_position,\n",
+ " current_position=current_position,\n",
+ " trade_date=trade_date,\n",
+ " )\n",
+ "\n",
+ " return order_list"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from qlib.data.dataset.loader import QlibDataLoader\n",
+ "from qlib.data.dataset.handler import DataHandler\n",
+ "from qlib.model.riskmodel import ShrinkCovEstimator\n",
+ "from qlib.portfolio.optimizer import PortfolioOptimizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "[35366:MainThread](2020-11-27 10:31:56,951) INFO - qlib.timer - [log.py:81] - Time cost: 6.763s | Loading data Done\n",
+ "[35366:MainThread](2020-11-27 10:31:56,953) INFO - qlib.timer - [log.py:81] - Time cost: 6.766s | Init data Done\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_loader = QlibDataLoader([\"$close\"])\n",
+ "data_handler = DataHandler(\"all\", \"2015-01-01\", \"2020-08-01\", data_loader)\n",
+ "cov_estimator = ShrinkCovEstimator(nan_option=\"mask\")\n",
+ "optimizer = PortfolioOptimizer(\"mvo\", lamb=2, delta=0.2, tol=1e-5)\n",
+ "strategy = OptBasedStrategy(data_handler, cov_estimator, optimizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "1': 0.08936553334387595, 'SH601800': 0.011014844457113308, 'SH601939': 0.013378001170219945, 'SH603993': 0.013820193926861863, 'SZ000338': 0.002455991798001457, 'SZ000423': 0.004893338273543826, 'SZ000538': 0.010686211189620477, 'SZ002065': 0.09095125419435357, 'SZ002074': 0.010299013738522475, 'SZ002085': 0.19844965949420615, 'SZ002236': 0.09210003831704765, 'SZ002310': 0.05664352912360013, 'SZ300017': 0.0197442255539771}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272224, 'SH600009': 604839, 'SH600018': 3097398, 'SH600028': 335726, 'SH600196': 23243, 'SH600276': 71634, 'SH600519': 17354, 'SH600585': 269686, 'SH600900': 2501521, 'SH601111': 2400659, 'SH601800': 334062, 'SH601939': 1283164, 'SH603993': 742901, 'SZ000338': 95285, 'SZ000423': 21697, 'SZ000538': 14518, 'SZ002065': 498253, 'SZ002074': 111674, 'SZ002085': 591507, 'SZ002236': 394197, 'SZ002310': 2202674, 'SZ300017': 206128}\n",
+ "target weight: {'SH600000': 0.02310668460556249, 'SH600009': 0.06170206213753432, 'SH600018': 0.027608180837257277, 'SH600028': 0.00971532319525714, 'SH600196': 0.0036133308423111116, 'SH600276': 0.093195014492093, 'SH600519': 0.013476706174774766, 'SH600585': 0.036024919027310476, 'SH600660': 0.04512159672692613, 'SH600900': 0.12506534473579556, 'SH601939': 0.013494851810297546, 'SH603993': 0.07619418669734077, 'SZ000338': 0.0024673392047414363, 'SZ000423': 0.00485981529404862, 'SZ000538': 0.010602880875660015, 'SZ002065': 0.09064325205359221, 'SZ002074': 0.0011889996597580427, 'SZ002085': 0.1982091371262038, 'SZ002236': 0.09254320484936242, 'SZ002310': 0.05152917909181458, 'SZ002466': 0.00014732765084648903, 'SZ300017': 0.019490662910321074}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 272079, 'SH600009': 604359, 'SH600018': 3095205, 'SH600028': 335471, 'SH600196': 23407, 'SH600276': 71567, 'SH600519': 17345, 'SH600585': 269447, 'SH600660': 129265, 'SH600900': 2499305, 'SH601939': 1282317, 'SH603993': 4058172, 'SZ000338': 95223, 'SZ000423': 21703, 'SZ000538': 14509, 'SZ002065': 497821, 'SZ002074': 12787, 'SZ002085': 590955, 'SZ002236': 393895, 'SZ002310': 2190685, 'SZ002466': 4483, 'SZ300017': 205994}\n",
+ "target weight: {'SH600000': 0.0014042138463464568, 'SH600009': 0.11511740651805806, 'SH600018': 0.026968513725965638, 'SH600028': 0.009566603496832042, 'SH600150': 0.016339328084607228, 'SH600276': 0.09374974543357856, 'SH600489': 0.021876512936684123, 'SH600585': 0.035840818294258524, 'SH600900': 0.12414161958870683, 'SH601888': 0.005682635273269834, 'SH601939': 0.013289788356428228, 'SH603993': 0.07491407610535435, 'SZ000338': 0.002426716760042838, 'SZ000423': 0.00492071038737461, 'SZ000503': 0.005617017904986693, 'SZ000538': 0.010859006699485451, 'SZ002065': 0.08924691553942904, 'SZ002085': 0.19757848255238786, 'SZ002236': 0.09381012783787722, 'SZ002310': 0.03737359938389514, 'SZ300017': 0.01927616131502695}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16809, 'SH600009': 1075516, 'SH600018': 3091248, 'SH600028': 335128, 'SH600150': 114804, 'SH600276': 71473, 'SH600489': 66586, 'SH600585': 268644, 'SH600900': 2496175, 'SH601888': 173824, 'SH601939': 1281108, 'SH603993': 4052802, 'SZ000338': 95107, 'SZ000423': 21684, 'SZ000503': 80461, 'SZ000538': 14507, 'SZ002065': 497197, 'SZ002085': 590211, 'SZ002236': 393412, 'SZ002310': 1573728, 'SZ300017': 205818}\n",
+ "target weight: {'SH600000': 0.0013962189421662084, 'SH600009': 0.09330267135244051, 'SH600018': 0.026443154116291615, 'SH600028': 0.009581412428525829, 'SH600150': 0.016443917649559808, 'SH600276': 0.09378402212481758, 'SH600703': 0.0005233118350013756, 'SH600741': 0.10117549074044105, 'SH600900': 0.12435147566444608, 'SH601888': 0.00560250787284307, 'SH601939': 0.013238798853730008, 'SH603993': 0.07455231781733267, 'SZ000423': 0.0048695925705555185, 'SZ000503': 0.006070996956328167, 'SZ000538': 0.010870567565742796, 'SZ002065': 0.08722983720892508, 'SZ002074': 0.00037126948590009574, 'SZ002085': 0.19840484837030906, 'SZ002236': 0.09365186287123867, 'SZ002310': 0.03806080531862309, 'SZ300017': 7.492025186876957e-05}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16889, 'SH600009': 867443, 'SH600018': 3086467, 'SH600028': 334573, 'SH600150': 114383, 'SH600276': 71360, 'SH600703': 1760, 'SH600741': 665366, 'SH600900': 2491839, 'SH601888': 173465, 'SH601939': 1278590, 'SH603993': 4045939, 'SZ000423': 21674, 'SZ000503': 80212, 'SZ000538': 14499, 'SZ002065': 496361, 'SZ002074': 4086, 'SZ002085': 589224, 'SZ002236': 392766, 'SZ002310': 1571463, 'SZ300017': 805}\n",
+ "target weight: {'SH600000': 0.0014143911110003147, 'SH600018': 0.026834186435965166, 'SH600028': 0.00961324990522086, 'SH600150': 0.015905361405158292, 'SH600276': 0.09486308638260738, 'SH600685': 1.0253334545374858e-06, 'SH600703': 0.0005108576602907958, 'SH600741': 0.10252334336233063, 'SH600900': 0.1250632059809011, 'SH601888': 0.005830869532670813, 'SH601939': 0.01336945356138906, 'SH603993': 0.07101851124599835, 'SZ000423': 0.004899981502195361, 'SZ000503': 0.006113894785564276, 'SZ000538': 0.011081925761176491, 'SZ000709': 1.06442568357325e-06, 'SZ002065': 0.08812103684766726, 'SZ002074': 0.0003564773234700175, 'SZ002085': 0.19097427428977284, 'SZ002236': 0.09299395368630246, 'SZ002310': 0.03841630892378685, 'SZ002475': 0.10001934454071283, 'SZ300017': 7.322667303400442e-05}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16886, 'SH600018': 3080789, 'SH600028': 334087, 'SH600150': 114360, 'SH600276': 71234, 'SH600685': 10, 'SH600703': 1709, 'SH600741': 663932, 'SH600900': 2486951, 'SH601888': 173417, 'SH601939': 1276335, 'SH603993': 3740672, 'SZ000423': 21667, 'SZ000503': 80191, 'SZ000538': 14495, 'SZ000709': 11, 'SZ002065': 495371, 'SZ002074': 3867, 'SZ002085': 588051, 'SZ002236': 392002, 'SZ002310': 1568834, 'SZ002475': 1264636, 'SZ300017': 809}\n",
+ "target weight: {'SH600000': 0.0013872765178790307, 'SH600018': 0.026321999857337998, 'SH600028': 0.009491029058787367, 'SH600150': 0.015749871987744815, 'SH600276': 0.09581999547114961, 'SH600703': 0.000518490273176083, 'SH600741': 0.1037547619508012, 'SH600900': 0.12396253436063161, 'SH601258': 0.02298494942988327, 'SH601888': 0.005915886046387033, 'SH601939': 0.013177336599075601, 'SH603993': 0.06888468621566025, 'SZ000423': 0.005102036718661418, 'SZ000503': 0.00602692511970311, 'SZ000538': 0.011127923667697532, 'SZ000709': 0.07688609680386178, 'SZ002065': 0.08693397271897534, 'SZ002074': 0.000347445594871718, 'SZ002085': 0.1905176824564206, 'SZ002236': 0.035835596544641496, 'SZ002475': 0.09918059167278087, 'SZ300017': 7.291118905149903e-05}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16948, 'SH600018': 3086676, 'SH600028': 334750, 'SH600150': 114560, 'SH600276': 71372, 'SH600703': 1715, 'SH600741': 665129, 'SH600900': 2491433, 'SH601258': 4190669, 'SH601888': 174070, 'SH601939': 1278836, 'SH603993': 3747283, 'SZ000423': 21744, 'SZ000503': 80490, 'SZ000538': 14538, 'SZ000709': 871429, 'SZ002065': 496245, 'SZ002074': 3887, 'SZ002085': 589120, 'SZ002236': 145147, 'SZ002475': 1268582, 'SZ300017': 814}\n",
+ "target weight: {'SH600000': 0.001373124016867567, 'SH600018': 0.02646941123076474, 'SH600028': 0.009458335378810856, 'SH600150': 0.015442533996257352, 'SH600276': 0.09620341387657301, 'SH600649': 0.012613476480118908, 'SH600703': 0.0005280976985716832, 'SH600741': 0.06577156829314017, 'SH600900': 0.12455488881029539, 'SH601258': 0.02270943336842379, 'SH601939': 0.013066707696697587, 'SH603993': 0.0649427819283919, 'SZ000423': 0.0051167756388828005, 'SZ000503': 0.006076486564538039, 'SZ000709': 0.0770418453012855, 'SZ000778': 0.08738918304165759, 'SZ002065': 0.08804613990036694, 'SZ002074': 0.00034315924263262563, 'SZ002085': 0.18241434394629127, 'SZ002475': 0.10035998625624482, 'SZ300017': 7.809604376099223e-05}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16935, 'SH600018': 3089469, 'SH600028': 334906, 'SH600150': 114496, 'SH600276': 71430, 'SH600649': 337388, 'SH600703': 1714, 'SH600741': 419916, 'SH600900': 2493978, 'SH601258': 4194599, 'SH601939': 1279661, 'SH603993': 3750968, 'SZ000423': 21734, 'SZ000503': 80440, 'SZ000709': 872293, 'SZ000778': 366855, 'SZ002065': 496756, 'SZ002074': 3880, 'SZ002085': 564610, 'SZ002475': 1269872, 'SZ300017': 812}\n",
+ "target weight: {'SH600000': 0.0013497287789570015, 'SH600018': 0.02647482761554837, 'SH600028': 0.00941080088689994, 'SH600150': 0.01556139303593115, 'SH600276': 0.09732218714743374, 'SH600649': 0.012606184789019243, 'SH600703': 0.0005334649726542859, 'SH600900': 0.12593267687041163, 'SH601258': 0.021199485570796834, 'SH601939': 0.013025993149697816, 'SH603993': 0.06446918682668012, 'SZ000423': 0.005311875734339093, 'SZ000503': 0.006125989728635501, 'SZ000709': 0.0707610058353687, 'SZ000778': 0.14004715956352495, 'SZ002065': 0.08746446321200681, 'SZ002074': 0.00033710686535540885, 'SZ002085': 0.15238971653801253, 'SZ002146': 0.042585776887618575, 'SZ002475': 0.10701429615740456, 'SZ300017': 7.667981013711115e-05}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 17031, 'SH600018': 3109084, 'SH600028': 336978, 'SH600150': 115126, 'SH600276': 71888, 'SH600649': 339316, 'SH600703': 1724, 'SH600900': 2510148, 'SH601258': 4237748, 'SH601939': 1287810, 'SH603993': 3775382, 'SZ000423': 21853, 'SZ000503': 80885, 'SZ000709': 878077, 'SZ000778': 625157, 'SZ002065': 499988, 'SZ002074': 3901, 'SZ002085': 469624, 'SZ002146': 2000993, 'SZ002475': 1278084, 'SZ300017': 814}\n",
+ "target weight: {'SH600000': 0.0013594926998639766, 'SH600009': 0.021101252574639438, 'SH600028': 0.009528554544265834, 'SH600150': 0.015013601602404225, 'SH600276': 0.09860402207319302, 'SH600649': 0.01292550325031454, 'SH600685': 0.00703471182662378, 'SH600703': 0.0005218767517596246, 'SH600900': 0.12786995199482584, 'SH601258': 0.04401496515184404, 'SH601398': 0.025932829520167643, 'SH601939': 0.0134408200189716, 'SH603993': 0.06319752369639879, 'SZ000423': 0.005221187626834546, 'SZ000503': 0.006085670359590286, 'SZ000568': 0.003081214755480397, 'SZ000709': 0.07061122716452324, 'SZ000778': 0.1379488795662632, 'SZ000839': 0.019142903464547063, 'SZ002065': 0.04714685528331623, 'SZ002074': 0.00033291622875151913, 'SZ002085': 0.11947661465752588, 'SZ002146': 0.043205942689553425, 'SZ002310': 0.0009243182551654129, 'SZ002475': 0.106199974013018, 'SZ300017': 7.709323254732814e-05}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16933, 'SH600009': 196068, 'SH600028': 337025, 'SH600150': 115100, 'SH600276': 71926, 'SH600649': 339354, 'SH600685': 75328, 'SH600703': 1713, 'SH600900': 2511928, 'SH601258': 8791935, 'SH601398': 1146896, 'SH601939': 1288215, 'SH603993': 3777819, 'SZ000423': 21728, 'SZ000503': 80869, 'SZ000568': 10375, 'SZ000709': 878683, 'SZ000778': 625604, 'SZ000839': 312116, 'SZ002065': 268413, 'SZ002074': 3860, 'SZ002085': 369761, 'SZ002146': 2002072, 'SZ002310': 40341, 'SZ002475': 1278918, 'SZ300017': 811}\n",
+ "target weight: {'SH600000': 0.0013764694393366029, 'SH600009': 0.021541655860797534, 'SH600028': 0.009752609535237182, 'SH600276': 0.06514222178877259, 'SH600649': 0.01273168785031133, 'SH600685': 0.006989932070614982, 'SH600900': 0.12998548252109676, 'SH601258': 0.13157540821422453, 'SH601398': 0.02641881439805636, 'SH601939': 0.0136141957873422, 'SH603993': 0.0602411123337629, 'SZ000503': 0.006084251045333903, 'SZ000709': 0.06977363144499521, 'SZ000778': 0.1385461140272643, 'SZ000839': 0.018579865431307987, 'SZ002065': 0.046270476942690986, 'SZ002074': 0.00025974854597178115, 'SZ002085': 0.10060756172850334, 'SZ002146': 0.043204792194791966, 'SZ002310': 0.0009022784286642987, 'SZ002466': 0.011748866835406593, 'SZ002475': 0.08457581284822364, 'SZ300017': 7.701070501151889e-05}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16938, 'SH600009': 196239, 'SH600028': 337355, 'SH600276': 46535, 'SH600649': 339479, 'SH600685': 75274, 'SH600900': 2514488, 'SH601258': 26730440, 'SH601398': 1148157, 'SH601939': 1289259, 'SH603993': 3781937, 'SZ000503': 80900, 'SZ000709': 879645, 'SZ000778': 626285, 'SZ000839': 312384, 'SZ002065': 268717, 'SZ002074': 3093, 'SZ002085': 309206, 'SZ002146': 2003901, 'SZ002310': 39782, 'SZ002466': 367691, 'SZ002475': 1026389, 'SZ300017': 812}\n",
+ "target weight: {'SH600000': 0.0013689894888766726, 'SH600009': 0.021087495457198752, 'SH600028': 0.009589419355091226, 'SH600276': 0.0644304399184473, 'SH600535': 0.016420787426513667, 'SH600649': 0.0267771761277641, 'SH600900': 0.12784455237901315, 'SH601169': 0.004374459372110214, 'SH601258': 0.13288651981531077, 'SH601398': 0.02615927477879055, 'SH601939': 0.013573361058977978, 'SH603993': 1.157895161672162e-06, 'SZ000503': 0.009069218941980683, 'SZ000709': 0.07014466816191627, 'SZ000778': 0.13956352821962528, 'SZ002065': 0.045206445945654664, 'SZ002085': 0.08649963592018277, 'SZ002146': 0.04234588186007612, 'SZ002310': 0.0008924777422846245, 'SZ002466': 0.07334842360184116, 'SZ002475': 0.08834296814868704, 'SZ300017': 7.311841306821287e-05}\n",
+ "Exception: ('SH601169', Timestamp('2017-04-25 00:00:00'))\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16929, 'SH600009': 196092, 'SH600028': 337333, 'SH600276': 46571, 'SH600535': 57649, 'SH600649': 731641, 'SH600900': 2515321, 'SH601258': 26740467, 'SH601398': 1148635, 'SH601939': 1289434, 'SH603993': 72, 'SZ000503': 122157, 'SZ000709': 879908, 'SZ000778': 626506, 'SZ002065': 268767, 'SZ002085': 267906, 'SZ002146': 2004576, 'SZ002310': 39745, 'SZ002466': 2332750, 'SZ002475': 1026858, 'SZ300017': 806}\n",
+ "target weight: {'SH600000': 0.0013439859873209908, 'SH600009': 0.02075652616964347, 'SH600028': 0.00939963933310415, 'SH600276': 0.06236017906066887, 'SH600535': 0.016369568294734148, 'SH600649': 0.025541724367766302, 'SH600900': 0.12768966131041845, 'SH601258': 0.1370446945486361, 'SH601398': 0.02601619218529119, 'SH601939': 0.013440958024818669, 'SH603993': 4.144559709761373e-06, 'SZ000503': 0.0084237188568659, 'SZ000568': 0.020576387679160105, 'SZ000709': 0.056783757531829446, 'SZ000778': 0.06920027928808208, 'SZ002008': 0.07943378393922318, 'SZ002065': 0.045339177613740886, 'SZ002085': 0.08505902525865962, 'SZ002146': 0.031624633954490035, 'SZ002310': 0.0008996156348854183, 'SZ002466': 0.0764983539831682, 'SZ002475': 0.086193992434369}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SZ300017': 812.4573136217659, 'SH600000': 16923, 'SH600009': 196076, 'SH600028': 337279, 'SH600276': 46567, 'SH600535': 57624, 'SH600649': 731549, 'SH600900': 2515891, 'SH601258': 26747448, 'SH601398': 1148886, 'SH601939': 1289307, 'SH603993': 263, 'SZ000503': 122158, 'SZ000568': 69471, 'SZ000709': 700781, 'SZ000778': 302643, 'SZ002008': 746285, 'SZ002065': 268804, 'SZ002085': 267988, 'SZ002146': 1473970, 'SZ002310': 39739, 'SZ002466': 2333288, 'SZ002475': 1027134}\n",
+ "target weight: {'SH600000': 0.0014508867295425067, 'SH600009': 0.022137935734971876, 'SH600028': 0.01003980705499816, 'SH600276': 0.065554410760754, 'SH600535': 0.017337663954140436, 'SH600649': 0.026752732524884384, 'SH600900': 0.13610376526017787, 'SH601258': 0.14230666244775886, 'SH601398': 0.027847743092481312, 'SH601939': 0.014306563408357105, 'SH603993': 2.7770868647848817e-06, 'SZ000069': 0.10104502775773525, 'SZ000503': 0.009049444347506782, 'SZ000568': 0.005686495401232644, 'SZ000778': 0.0715782861850023, 'SZ002008': 0.08609584908472251, 'SZ002065': 0.04706561122827146, 'SZ002085': 0.09099179117275048, 'SZ002146': 0.03204301334262787, 'SZ002475': 0.09241758644387384, 'SZ300017': 0.00018594702102337797}\n",
+ "target position: {'SZ000709': 700825.0269758024, 'SZ002299': 6184584.0980107365, 'SH600000': 16845, 'SH600009': 195098, 'SH600028': 335689, 'SH600276': 46340, 'SH600535': 57343, 'SH600649': 728078, 'SH600900': 2504242, 'SH601258': 26624542, 'SH601398': 1143577, 'SH601939': 1283067, 'SH603993': 160, 'SZ000069': 367637, 'SZ000503': 121565, 'SZ000568': 17626, 'SZ000778': 301250, 'SZ002008': 742790, 'SZ002065': 267559, 'SZ002085': 266737, 'SZ002146': 1467579, 'SZ002475': 1022346, 'SZ300017': 1776}\n",
+ "target weight: {'SH600000': 0.0013484985106016394, 'SH600009': 0.020750773768622693, 'SH600028': 0.009285673867962157, 'SH600104': 2.9067007814076732e-05, 'SH600196': 0.10012804077099052, 'SH600276': 0.05943563439541343, 'SH600535': 0.015902136087846228, 'SH600649': 0.025189836387314323, 'SH600900': 0.12584805827140388, 'SH601111': 6.857382365314848e-06, 'SH601258': 0.03895938466363849, 'SH601398': 0.025753888553878806, 'SH601939': 0.013275755331575599, 'SH603993': 4.249178615404585e-06, 'SZ000069': 0.09445579375504781, 'SZ000503': 0.008532747266799033, 'SZ000568': 0.0052599046052527266, 'SZ000709': 0.06003418476540357, 'SZ000778': 0.06923031488245988, 'SZ002008': 0.07903025205993618, 'SZ002065': 0.04448484691775433, 'SZ002085': 0.08426354045447453, 'SZ002146': 0.031142767130486235, 'SZ002475': 0.08747938111190227, 'SZ300017': 0.00016841662419817417}\n",
+ "target position: {'SZ002299': 6184584.0980107365, 'SH600000': 16906, 'SH600009': 195107, 'SH600028': 335257, 'SH600104': 197, 'SH600196': 630404, 'SH600276': 46282, 'SH600535': 57311, 'SH600649': 727170, 'SH600900': 2500379, 'SH601111': 203, 'SH601258': 7443096, 'SH601398': 1142014, 'SH601939': 1281361, 'SH603993': 263, 'SZ000069': 366998, 'SZ000503': 121479, 'SZ000568': 17699, 'SZ000709': 699639, 'SZ000778': 300752, 'SZ002008': 741767, 'SZ002065': 267133, 'SZ002085': 266334, 'SZ002146': 1465489, 'SZ002475': 1020693, 'SZ300017': 1756}\n",
+ "target weight: {'SH600000': 0.0012976336004362882, 'SH600009': 0.0204756895024156, 'SH600028': 0.008883617000656601, 'SH600104': 2.592943319382378e-05, 'SH600196': 0.09617041827497698, 'SH600276': 0.05681162545715886, 'SH600535': 0.015294256733040745, 'SH600649': 0.02417676167926707, 'SH600900': 0.12233373885315162, 'SH601398': 0.024531954099214746, 'SH601628': 0.005044154324745466, 'SH601888': 0.09500034426651846, 'SH601939': 0.012657033879067425, 'SH603993': 4.079522960136806e-06, 'SZ000069': 0.09054142453059062, 'SZ000503': 0.008036587259744734, 'SZ000568': 0.0049533657881637655, 'SZ000778': 0.06904486736535222, 'SZ002008': 0.06688985213943154, 'SZ002065': 0.04278977877238287, 'SZ002085': 0.0820368284038888, 'SZ002299': 0.06899317887598991, 'SZ002475': 0.08384652594205952, 'SZ300017': 0.00016035416530955983}\n",
+ "target position: {'SH601258': 7443495.190430395, 'SH600000': 16952, 'SH600009': 195676, 'SH600028': 336044, 'SH600104': 183, 'SH600196': 631454, 'SH600276': 46372, 'SH600535': 57498, 'SH600649': 728582, 'SH600900': 2504660, 'SH601398': 1143938, 'SH601628': 695470, 'SH601888': 2951253, 'SH601939': 1283887, 'SH603993': 255, 'SZ000069': 367641, 'SZ000503': 121875, 'SZ000568': 17775, 'SZ000778': 301255, 'SZ002008': 638620, 'SZ002065': 267645, 'SZ002085': 266802, 'SZ002299': 6194843, 'SZ002475': 1022527, 'SZ300017': 1765}\n",
+ "target weight: {'SH600000': 0.0013469483722729403, 'SH600028': 0.009286467498269333, 'SH600104': 2.368500734977497e-05, 'SH600196': 0.10145424564201923, 'SH600276': 0.06002237364700993, 'SH600535': 0.01588332650422844, 'SH600649': 0.025440421851940002, 'SH600900': 0.1279028471227695, 'SH601258': 0.035917606048396986, 'SH601398': 0.02559318344055778, 'SH601628': 0.005221942888216608, 'SH601888': 0.14928498761757883, 'SH601939': 0.013161430940131148, 'SH603993': 4.350147095904942e-06, 'SZ000069': 0.14038473724819095, 'SZ000503': 0.008556251357999256, 'SZ000568': 0.005243511514392524, 'SZ002008': 0.06824325050397591, 'SZ002065': 0.04420632869308568, 'SZ002085': 0.074424247013131, 'SZ002299': 0.0010812901181988855, 'SZ002475': 0.0871460668952185, 'SZ300017': 0.00017049992832446128}\n",
+ "target position: {'SZ000778': 301254.84776855103, 'SH600000': 16873, 'SH600028': 335064, 'SH600104': 156, 'SH600196': 629613, 'SH600276': 46235, 'SH600535': 57245, 'SH600649': 726346, 'SH600900': 2497776, 'SH601258': 7423462, 'SH601398': 1140689, 'SH601628': 692346, 'SH601888': 4557826, 'SH601939': 1279908, 'SH603993': 261, 'SZ000069': 551887, 'SZ000503': 121344, 'SZ000568': 17697, 'SZ002008': 636943, 'SZ002065': 266904, 'SZ002085': 231781, 'SZ002299': 97527, 'SZ002475': 1019747, 'SZ300017': 1749}\n"
+ ]
+ },
+ {
+ "output_type": "error",
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[1;31m# backtest & analysis\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 31\u001b[0m \u001b[0mpar\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mport_analysis_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 32\u001b[1;33m \u001b[0mpar\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgenerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[1;32md:\\qlib\\qlib\\workflow\\record_temp.py\u001b[0m in \u001b[0;36mgenerate\u001b[1;34m(self, **kwargs)\u001b[0m\n\u001b[0;32m 230\u001b[0m \u001b[1;31m# custom strategy and get backtest\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 231\u001b[0m \u001b[0mpred_score\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 232\u001b[1;33m \u001b[0mreport_normal\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpositions_normal\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnormal_backtest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpred_score\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstrategy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstrategy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbacktest_config\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"report_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mreport_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 234\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrecorder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msave_objects\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m\"positions_normal.pkl\"\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mpositions_normal\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0martifact_path\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPortAnaRecord\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32md:\\qlib\\qlib\\contrib\\evaluate.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, account, shift, benchmark, verbose, **kwargs)\u001b[0m\n\u001b[0;32m 269\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 270\u001b[0m \u001b[0maccount\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0maccount\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 271\u001b[1;33m \u001b[0mbenchmark\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbenchmark\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 272\u001b[0m )\n\u001b[0;32m 273\u001b[0m \u001b[1;31m# for compatibility of the old API. return the dict positions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32md:\\qlib\\qlib\\contrib\\backtest\\backtest.py\u001b[0m in \u001b[0;36mbacktest\u001b[1;34m(pred, strategy, trade_exchange, shift, verbose, account, benchmark)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[0mtrade_exchange\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_exchange\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[0mpred_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mpred_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mtrade_date\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtrade_date\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m )\n\u001b[0;32m 104\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m\u001b[0m in \u001b[0;36mgenerate_order_list\u001b[1;34m(self, score_series, current, trade_exchange, pred_date, trade_date)\u001b[0m\n\u001b[0;32m 76\u001b[0m \u001b[1;31m# optimize target portfolio\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 77\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msum\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 78\u001b[1;33m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minit_weight\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 79\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 80\u001b[0m \u001b[0mtarget_weight\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcov\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mscore_series\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 100\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 101\u001b[0m \u001b[1;31m# optimize\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m \u001b[0mw\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 103\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 104\u001b[0m \u001b[1;31m# restore index if needed\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 126\u001b[0m \u001b[1;31m# mean-variance\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 127\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmethod\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOPT_MVO\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 128\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_optimize_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 129\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 130\u001b[0m \u001b[1;31m# risk parity\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_optimize_mvo\u001b[1;34m(self, S, u, w0)\u001b[0m\n\u001b[0;32m 162\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mlamb\u001b[0m\u001b[0;31m`\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mthe\u001b[0m \u001b[0mrisk\u001b[0m \u001b[0maversion\u001b[0m \u001b[0mparameter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 163\u001b[0m \"\"\"\n\u001b[1;32m--> 164\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_solve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_objective_mvo\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mu\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_constrains\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 165\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 166\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_optimize_rp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mS\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mw0\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32md:\\qlib\\qlib\\portfolio\\optimizer.py\u001b[0m in \u001b[0;36m_solve\u001b[1;34m(self, n, obj, bounds, cons)\u001b[0m\n\u001b[0;32m 252\u001b[0m \u001b[1;31m# solve\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[0mx0\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mones\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mn\u001b[0m \u001b[1;31m# init results\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 254\u001b[1;33m \u001b[0msol\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mso\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mminimize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwrapped_obj\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbounds\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mbounds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconstraints\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcons\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtol\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 255\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0msol\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msuccess\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 256\u001b[0m \u001b[0mwarnings\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34mf\"optimization not success ({sol.status})\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\_minimize.py\u001b[0m in \u001b[0;36mminimize\u001b[1;34m(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)\u001b[0m\n\u001b[0;32m 624\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'slsqp'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 625\u001b[0m return _minimize_slsqp(fun, x0, args, jac, bounds,\n\u001b[1;32m--> 626\u001b[1;33m constraints, callback=callback, **options)\n\u001b[0m\u001b[0;32m 627\u001b[0m \u001b[1;32melif\u001b[0m \u001b[0mmeth\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'trust-constr'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 628\u001b[0m return _minimize_trustregion_constr(fun, x0, args, jac, hess, hessp,\n",
+ "\u001b[1;32m~\\AppData\\Local\\Continuum\\miniconda3\\envs\\qlib\\lib\\site-packages\\scipy\\optimize\\slsqp.py\u001b[0m in \u001b[0;36m_minimize_slsqp\u001b[1;34m(func, x0, args, jac, bounds, constraints, maxiter, ftol, iprint, disp, eps, callback, finite_diff_rel_step, **unknown_options)\u001b[0m\n\u001b[0;32m 419\u001b[0m n1, n2, n3)\n\u001b[0;32m 420\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 421\u001b[1;33m \u001b[1;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# objective and constraint evaluation required\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 422\u001b[0m \u001b[0mfx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 423\u001b[0m \u001b[0mc\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_eval_constraint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcons\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
+ "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "###################################\n",
+ "# prediction, backtest & analysis\n",
+ "###################################\n",
+ "port_analysis_config = {\n",
+ " \"strategy\": strategy,\n",
+ " \"backtest\": {\n",
+ " \"verbose\": False,\n",
+ " \"limit_threshold\": 0.095,\n",
+ " \"account\": 100000000,\n",
+ " \"benchmark\": benchmark,\n",
+ " \"deal_price\": \"close\",\n",
+ " \"open_cost\": 0.0005,\n",
+ " \"close_cost\": 0.0015,\n",
+ " \"min_cost\": 5,\n",
+ " },\n",
+ "}\n",
+ "\n",
+ "\n",
+ "# backtest and analysis\n",
+ "with R.start(experiment_name=\"backtest_analysis\"):\n",
+ " recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
+ " model = recorder.load_object(\"trained_model\")\n",
+ "\n",
+ " # prediction\n",
+ " recorder = R.get_recorder()\n",
+ " ba_rid = recorder.id\n",
+ " sr = SignalRecord(model, dataset, recorder)\n",
+ " sr.generate()\n",
+ "\n",
+ " # backtest & analysis\n",
+ " par = PortAnaRecord(recorder, port_analysis_config)\n",
+ " par.generate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ]
+}
\ No newline at end of file
diff --git a/examples/run_all_model.py b/examples/run_all_model.py
index f8894afd3..2f6c4299e 100644
--- a/examples/run_all_model.py
+++ b/examples/run_all_model.py
@@ -10,6 +10,7 @@ import shutil
import tempfile
import statistics
from pathlib import Path
+from operator import xor
from subprocess import Popen, PIPE
from threading import Thread
from pprint import pprint
@@ -174,11 +175,21 @@ def cal_mean_std(results) -> dict:
# function to get all the folders benchmark folder
-def get_all_folders() -> dict:
+def get_all_folders(models, exclude) -> dict:
folders = dict()
+ if isinstance(models, str):
+ model_list = models.split(",")
+ models = [m.lower().strip("[ ]") for m in model_list]
+ elif isinstance(models, list):
+ models = [m.lower() for m in models]
+ elif models is None:
+ models = [f.name.lower() for f in os.scandir("benchmarks")]
+ else:
+ raise ValueError("Input models type is not supported. Please provide str or list without space.")
for f in os.scandir("benchmarks"):
- path = Path("benchmarks") / f.name
- if f.name != "TFT":
+ add = xor(bool(f.name.lower() in models), bool(exclude))
+ if add:
+ path = Path("benchmarks") / f.name
folders[f.name] = str(path.resolve())
return folders
@@ -226,13 +237,44 @@ def gen_and_save_md_table(metrics):
# function to run the all the models
-def run(times=1):
+def run(times=1, models=None, exclude=False):
"""
Please be aware that this function can only work under Linux. MacOS and Windows will be supported in the future.
Any PR to enhance this method is highly welcomed.
+
+ Parameters:
+ -----------
+ times : int
+ determines how many times the model should be running.
+ models : str or list
+ determines the specific model or list of models to run or exclude.
+ exclude : boolean
+ determines whether the model being used is excluded or included.
+
+ Usage:
+ -------
+ Here are some use cases of the function in the bash:
+
+ .. code-block:: bash
+
+ # Case 1 - run all models multiple times
+ python run_all_model.py 3
+
+ # Case 2 - run specific models multiple times
+ python run_all_model.py 3 dnn
+
+ # Case 3 - run other models except those are given as arguments for multiple times
+ python run_all_model.py 3 [dnn,tft,lstm] True
+
+ # Case 4 - run specific models for one time
+ python run_all_model.py --models=[dnn,lightgbm]
+
+ # Case 5 - run other models except those are given as aruments for one time
+ python run_all_model.py --models=[dnn,tft,sfm] --exclude=True
+
"""
# get all folders
- folders = get_all_folders()
+ folders = get_all_folders(models, exclude)
# set up
compatible = True
if sys.version_info < (3, 3):
diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb
index 1b4183b29..692e52078 100644
--- a/examples/workflow_by_code.ipynb
+++ b/examples/workflow_by_code.ipynb
@@ -31,7 +31,8 @@
")\n",
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
"from qlib.workflow import R\n",
- "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord"
+ "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
+ "from qlib.utils import flatten_dict"
]
},
{
@@ -129,7 +130,7 @@
"\n",
"# start exp to train model\n",
"with R.start(experiment_name=\"train_model\"):\n",
- " R.log_paramters(**flatten_dict(task))\n",
+ " R.log_params(**flatten_dict(task))\n",
" model.fit(dataset)\n",
" R.save_objects(trained_model=model)\n",
" rid = R.get_recorder().id\n"
@@ -337,4 +338,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
-}
\ No newline at end of file
+}
diff --git a/examples/workflow_by_code_alstm.py b/examples/workflow_by_code_alstm.py
new file mode 100644
index 000000000..8fd9e3565
--- /dev/null
+++ b/examples/workflow_by_code_alstm.py
@@ -0,0 +1,138 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import sys
+from pathlib import Path
+
+import qlib
+import pandas as pd
+from qlib.config import REG_CN
+from qlib.contrib.strategy.strategy import TopkDropoutStrategy
+from qlib.contrib.evaluate import (
+ backtest as normal_backtest,
+ risk_analysis,
+)
+from qlib.utils import exists_qlib_data
+from qlib.utils import init_instance_by_config
+
+if __name__ == "__main__":
+
+ # use default data
+ provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
+ if not exists_qlib_data(provider_uri):
+ print(f"Qlib data is not found in {provider_uri}")
+ sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
+ from get_data import GetData
+
+ GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
+
+ qlib.init(provider_uri=provider_uri, region=REG_CN)
+
+ MARKET = "csi300"
+ BENCHMARK = "SH000300"
+
+ ###################################
+ # train model
+ ###################################
+ DATA_HANDLER_CONFIG = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": MARKET,
+ }
+
+ TRAINER_CONFIG = {
+ "train_start_time": "2008-01-01",
+ "train_end_time": "2014-12-31",
+ "validate_start_time": "2015-01-01",
+ "validate_end_time": "2016-12-31",
+ "test_start_time": "2017-01-01",
+ "test_end_time": "2020-08-01",
+ }
+
+ task = {
+ "model": {
+ "class": "ALSTM",
+ "module_path": "qlib.contrib.model.pytorch_alstm",
+ "kwargs": {
+ "d_feat": 6,
+ "hidden_size": 64,
+ "num_layers": 2,
+ "dropout": 0.0,
+ "n_epochs": 200,
+ "lr": 1e-3,
+ "early_stop": 20,
+ "batch_size": 800,
+ "metric": "IC",
+ "loss": "mse",
+ "seed": 0,
+ "GPU": "0",
+ "rnn_type": "GRU",
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "ALPHA360_Denoise",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": DATA_HANDLER_CONFIG,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ }
+ # You shoud record the data in specific sequence
+ # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
+ }
+
+ model = init_instance_by_config(task["model"])
+ dataset = init_instance_by_config(task["dataset"])
+ model.fit(dataset)
+
+ pred_score = model.predict(dataset)
+
+ # save pred_score to file
+ pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
+ pred_score_path.parent.mkdir(exist_ok=True, parents=True)
+ pred_score.to_pickle(pred_score_path)
+
+ ###################################
+ # backtest
+ ###################################
+ STRATEGY_CONFIG = {
+ "topk": 50,
+ "n_drop": 5,
+ }
+ BACKTEST_CONFIG = {
+ "verbose": False,
+ "limit_threshold": 0.095,
+ "account": 100000000,
+ "benchmark": BENCHMARK,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ }
+
+ # use default strategy
+ # custom Strategy, refer to: TODO: Strategy API url
+ strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
+ report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
+
+ ###################################
+ # analyze
+ # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
+ ###################################
+ analysis = dict()
+ analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
+ analysis["excess_return_with_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"] - report_normal["cost"]
+ )
+ analysis_df = pd.concat(analysis) # type: pd.DataFrame
+ print(analysis_df)
diff --git a/examples/workflow_by_code_gats.py b/examples/workflow_by_code_gats.py
index 3bb4edf08..20f3ae552 100644
--- a/examples/workflow_by_code_gats.py
+++ b/examples/workflow_by_code_gats.py
@@ -7,19 +7,15 @@ from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
-from qlib.contrib.model.pytorch_gats import GAT
-from qlib.contrib.data.handler import ALPHA360_Denoise
+
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data
-
-# from qlib.model.learner import train_model
from qlib.utils import init_instance_by_config
-import pickle
if __name__ == "__main__":
@@ -65,17 +61,16 @@ if __name__ == "__main__":
"d_feat": 6,
"hidden_size": 64,
"num_layers": 2,
- "dropout": 0.0,
+ "dropout": 0.7,
"n_epochs": 200,
- "lr": 1e-3,
+ "lr": 1e-4,
"early_stop": 20,
- "batch_size": 800,
"metric": "loss",
"loss": "mse",
"base_model": "LSTM",
"with_pretrain": True,
"seed": 0,
- "GPU": 0,
+ "GPU": "0",
},
},
"dataset": {
@@ -98,7 +93,6 @@ if __name__ == "__main__":
# "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
}
- # model = train_model(task)
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model.fit(dataset)
diff --git a/examples/workflow_by_code_gru.py b/examples/workflow_by_code_gru.py
index fdd0d9220..dece520d1 100644
--- a/examples/workflow_by_code_gru.py
+++ b/examples/workflow_by_code_gru.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":
"lr": 1e-3,
"early_stop": 20,
"batch_size": 800,
- "metric": "IC",
+ "metric": "loss",
"loss": "mse",
"seed": 0,
"GPU": 0,
diff --git a/examples/workflow_by_code_hats.py b/examples/workflow_by_code_hats.py
new file mode 100644
index 000000000..64bc860b4
--- /dev/null
+++ b/examples/workflow_by_code_hats.py
@@ -0,0 +1,136 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import sys
+from pathlib import Path
+import qlib
+import pandas as pd
+from qlib.config import REG_CN
+from qlib.contrib.strategy.strategy import TopkDropoutStrategy
+from qlib.contrib.evaluate import (
+ backtest as normal_backtest,
+ risk_analysis,
+)
+from qlib.utils import exists_qlib_data
+from qlib.utils import init_instance_by_config
+
+if __name__ == "__main__":
+
+ # use default data
+ provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
+ if not exists_qlib_data(provider_uri):
+ print(f"Qlib data is not found in {provider_uri}")
+ sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
+ from get_data import GetData
+
+ GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
+
+ qlib.init(provider_uri=provider_uri, region=REG_CN)
+
+ MARKET = "csi300"
+ BENCHMARK = "SH000300"
+
+ ###################################
+ # train model
+ ###################################
+ DATA_HANDLER_CONFIG = {
+ "start_time": "2008-01-01",
+ "end_time": "2020-08-01",
+ "fit_start_time": "2008-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": MARKET,
+ }
+
+ TRAINER_CONFIG = {
+ "train_start_time": "2008-01-01",
+ "train_end_time": "2014-12-31",
+ "validate_start_time": "2015-01-01",
+ "validate_end_time": "2016-12-31",
+ "test_start_time": "2017-01-01",
+ "test_end_time": "2020-08-01",
+ }
+
+ task = {
+ "model": {
+ "class": "HATS",
+ "module_path": "qlib.contrib.model.pytorch_hats",
+ "kwargs": {
+ "d_feat": 6,
+ "hidden_size": 64,
+ "num_layers": 2,
+ "dropout": 0.7,
+ "n_epochs": 200,
+ "lr": 1e-4,
+ "early_stop": 20,
+ "metric": "loss",
+ "loss": "mse",
+ "base_model": "LSTM",
+ "seed": 0,
+ "GPU": "2",
+ },
+ },
+ "dataset": {
+ "class": "DatasetH",
+ "module_path": "qlib.data.dataset",
+ "kwargs": {
+ "handler": {
+ "class": "ALPHA360_Denoise",
+ "module_path": "qlib.contrib.data.handler",
+ "kwargs": DATA_HANDLER_CONFIG,
+ },
+ "segments": {
+ "train": ("2008-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2016-12-31"),
+ "test": ("2017-01-01", "2020-08-01"),
+ },
+ },
+ }
+ # You shoud record the data in specific sequence
+ # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'],
+ }
+
+ model = init_instance_by_config(task["model"])
+ dataset = init_instance_by_config(task["dataset"])
+ model.fit(dataset, save_path="benchmarks/HATS/model_hat.pkl")
+
+ pred_score = model.predict(dataset)
+
+ # save pred_score to file
+ pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser()
+ pred_score_path.parent.mkdir(exist_ok=True, parents=True)
+ pred_score.to_pickle(pred_score_path)
+
+ ###################################
+ # backtest
+ ###################################
+ STRATEGY_CONFIG = {
+ "topk": 50,
+ "n_drop": 5,
+ }
+ BACKTEST_CONFIG = {
+ "verbose": False,
+ "limit_threshold": 0.095,
+ "account": 100000000,
+ "benchmark": BENCHMARK,
+ "deal_price": "close",
+ "open_cost": 0.0005,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ }
+
+ # use default strategy
+ # custom Strategy, refer to: TODO: Strategy API url
+ strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
+ report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG)
+
+ ###################################
+ # analyze
+ # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb
+ ###################################
+ analysis = dict()
+ analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
+ analysis["excess_return_with_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"] - report_normal["cost"]
+ )
+ analysis_df = pd.concat(analysis) # type: pd.DataFrame
+ print(analysis_df)
diff --git a/examples/workflow_by_code_sfm.py b/examples/workflow_by_code_sfm.py
index 1942bfb33..5bd91ded8 100644
--- a/examples/workflow_by_code_sfm.py
+++ b/examples/workflow_by_code_sfm.py
@@ -1,5 +1,15 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import sys
from pathlib import Path
@@ -62,21 +72,22 @@ if __name__ == "__main__":
"kwargs": {
"d_feat": 6,
"hidden_size": 64,
- "output_dim": 1,
- "freq_dim": 15,
+ "output_dim": 32,
+ "freq_dim": 25,
"dropout_W": 0.5,
"dropout_U": 0.5,
- "n_epochs": 10,
+ "n_epochs": 15,
"lr": 1e-3,
- "batch_size": 800,
+ "metric": "",
+ "batch_size": 1600,
"early_stop": 20,
"eval_steps": 5,
"loss": "mse",
"lr_decay": 0.96,
"lr_decay_steps": 100,
- "optimizer": "gd",
- "GPU": 1,
- "seed": 0,
+ "optimizer": "adam",
+ "GPU": 3,
+ "seed": 710,
},
},
"dataset": {
diff --git a/qlib/config.py b/qlib/config.py
index 640701ee5..869ea99c9 100644
--- a/qlib/config.py
+++ b/qlib/config.py
@@ -64,7 +64,7 @@ class Config:
REG_CN = "cn"
REG_US = "us"
-NUM_USABLE_CPU = multiprocessing.cpu_count() - 2
+NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
_default_config = {
# data provider config
diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py
index 3668a0cc0..e61d26254 100644
--- a/qlib/contrib/data/handler.py
+++ b/qlib/contrib/data/handler.py
@@ -10,6 +10,28 @@ from inspect import getfullargspec
import copy
+def check_transform_proc(proc_l, fit_start_time, fit_end_time):
+ new_l = []
+ for p in proc_l:
+ if not isinstance(p, Processor):
+ klass, pkwargs = get_cls_kwargs(p, processor_module)
+ args = getfullargspec(klass).args
+ if "fit_start_time" in args and "fit_end_time" in args:
+ assert (
+ fit_start_time is not None and fit_end_time is not None
+ ), "Make sure `fit_start_time` and `fit_end_time` are not None."
+ pkwargs.update(
+ {
+ "fit_start_time": fit_start_time,
+ "fit_end_time": fit_end_time,
+ }
+ )
+ new_l.append({"class": klass.__name__, "kwargs": pkwargs})
+ else:
+ new_l.append(p)
+ return new_l
+
+
class ALPHA360_Denoise(DataHandlerLP):
def __init__(self, instruments="csi500", start_time=None, end_time=None, fit_start_time=None, fit_end_time=None):
data_loader = {
@@ -83,28 +105,42 @@ class ALPHA360_Denoise(DataHandlerLP):
return fields, names
+_DEFAULT_LEARN_PROCESSORS = [
+ {"class": "DropnaLabel"},
+ {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
+]
+_DEFAULT_INFER_PROCESSORS = [
+ {"class": "ProcessInf", "kwargs": {}},
+ {"class": "ZScoreNorm", "kwargs": {}},
+ {"class": "Fillna", "kwargs": {}},
+]
+
+
class ALPHA360(DataHandlerLP):
- def __init__(self, instruments="csi500", start_time=None, end_time=None, fit_start_time=None, fit_end_time=None):
+ def __init__(
+ self,
+ instruments="csi500",
+ start_time=None,
+ end_time=None,
+ infer_processors=_DEFAULT_INFER_PROCESSORS,
+ learn_processors=_DEFAULT_LEARN_PROCESSORS,
+ fit_start_time=None,
+ fit_end_time=None,
+ **kwargs,
+ ):
+ infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
+ learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
+
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
"config": {
"feature": self.get_feature_config(),
- "label": self.get_label_config(),
+ "label": kwargs.get("label", self.get_label_config()),
},
},
}
- learn_processors = [
- {"class": "DropnaLabel", "kwargs": {"fields_group": "label"}},
- {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
- ]
- infer_processors = [
- {"class": "ProcessInf", "kwargs": {}},
- {"class": "ZscoreNorm", "kwargs": {"fit_start_time": fit_start_time, "fit_end_time": fit_end_time}},
- {"class": "Fillna", "kwargs": {}},
- ]
-
super().__init__(
instruments,
start_time,
@@ -168,39 +204,19 @@ class Alpha158(DataHandlerLP):
start_time=None,
end_time=None,
infer_processors=[],
- learn_processors=["DropnaLabel", {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}],
+ learn_processors=_DEFAULT_LEARN_PROCESSORS,
fit_start_time=None,
fit_end_time=None,
process_type=DataHandlerLP.PTYPE_A
+ **kwargs,
):
- def check_transform_proc(proc_l):
- new_l = []
- for p in proc_l:
- if not isinstance(p, Processor):
- klass, pkwargs = get_cls_kwargs(p, processor_module)
- args = getfullargspec(klass).args
- if "fit_start_time" in args and "fit_end_time" in args:
- assert (
- fit_start_time is not None and fit_end_time is not None
- ), "Make sure `fit_start_time` and `fit_end_time` are not None."
- pkwargs.update(
- {
- "fit_start_time": fit_start_time,
- "fit_end_time": fit_end_time,
- }
- )
- new_l.append({"class": klass.__name__, "kwargs": pkwargs})
- else:
- new_l.append(p)
- return new_l
-
- infer_processors = check_transform_proc(infer_processors)
- learn_processors = check_transform_proc(learn_processors)
+ infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
+ learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "QlibDataLoader",
"kwargs": {
- "config": {"feature": self.get_feature_config(), "label": self.get_label_config()},
+ "config": {"feature": self.get_feature_config(), "label": kwargs.get("label", self.get_label_config())},
},
}
super().__init__(
diff --git a/qlib/contrib/estimator/__init__.py b/qlib/contrib/estimator/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/qlib/contrib/estimator/config.py b/qlib/contrib/estimator/config.py
deleted file mode 100644
index 0d782c412..000000000
--- a/qlib/contrib/estimator/config.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-import yaml
-import copy
-import os
-import json
-import tempfile
-from pathlib import Path
-from ...config import REG_CN
-
-
-class EstimatorConfigManager(object):
- def __init__(self, config_path):
-
- if not config_path:
- raise ValueError("Config path is invalid.")
- self.config_path = config_path
-
- with open(config_path) as fp:
- config = yaml.load(fp, Loader=yaml.FullLoader)
- self.config = copy.deepcopy(config)
-
- self.ex_config = ExperimentConfig(config.get("experiment", dict()), self)
- self.data_config = DataConfig(config.get("data", dict()), self)
- self.model_config = ModelConfig(config.get("model", dict()), self)
- self.trainer_config = TrainerConfig(config.get("trainer", dict()), self)
- self.strategy_config = StrategyConfig(config.get("strategy", dict()), self)
- self.backtest_config = BacktestConfig(config.get("backtest", dict()), self)
- self.qlib_data_config = QlibDataConfig(config.get("qlib_data", dict()), self)
-
- # If the start_date and end_date are not given in data_config, they will be referred from the trainer_config.
- handler_start_date = self.data_config.handler_parameters.get("start_date", None)
- handler_end_date = self.data_config.handler_parameters.get("end_date", None)
- if handler_start_date is None:
- self.data_config.handler_parameters["start_date"] = self.trainer_config.parameters["train_start_date"]
- if handler_end_date is None:
- self.data_config.handler_parameters["end_date"] = self.trainer_config.parameters["test_end_date"]
-
-
-class ExperimentConfig(object):
- TRAIN_MODE = "train"
- TEST_MODE = "test"
-
- OBSERVER_FILE_STORAGE = "file_storage"
- OBSERVER_MONGO = "mongo"
-
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for experiment
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.name = config.get("name", "test_experiment")
- # The dir of the result of all the experiments
- self.global_dir = config.get("dir", os.path.dirname(CONFIG_MANAGER.config_path))
- # The dir of the result of current experiment
- self.ex_dir = os.path.join(self.global_dir, self.name)
- if not os.path.exists(self.ex_dir):
- os.makedirs(self.ex_dir)
- self.tmp_run_dir = tempfile.mkdtemp(dir=self.ex_dir)
- self.mode = config.get("mode", ExperimentConfig.TRAIN_MODE)
- self.sacred_dir = os.path.join(self.ex_dir, "sacred")
- self.observer_type = config.get("observer_type", ExperimentConfig.OBSERVER_FILE_STORAGE)
- self.mongo_url = config.get("mongo_url", None)
- self.db_name = config.get("db_name", None)
- self.finetune = config.get("finetune", False)
-
- # The path of the experiment id of the experiment
- self.exp_info_path = config.get("exp_info_path", os.path.join(self.ex_dir, "exp_info.json"))
- exp_info_dir = Path(self.exp_info_path).parent
- exp_info_dir.mkdir(parents=True, exist_ok=True)
-
- # Test mode config
- loader_args = config.get("loader", dict())
- if self.mode == ExperimentConfig.TEST_MODE or self.finetune:
- loader_exp_info_path = loader_args.get("exp_info_path", None)
- self.loader_model_index = loader_args.get("model_index", None)
- if (loader_exp_info_path is not None) and (os.path.exists(loader_exp_info_path)):
- with open(loader_exp_info_path) as fp:
- loader_dict = json.load(fp)
- for k, v in loader_dict.items():
- setattr(self, "loader_{}".format(k), v)
- # Check loader experiment id
- assert hasattr(self, "loader_id"), "If mode is test or finetune is True, loader must contain id."
- else:
- self.loader_id = loader_args.get("id", None)
- if self.loader_id is None:
- raise ValueError("If mode is test or finetune is True, loader must contain id.")
-
- self.loader_observer_type = loader_args.get("observer_type", self.observer_type)
- self.loader_name = loader_args.get("name", self.name)
- self.loader_dir = loader_args.get("dir", self.global_dir)
-
- self.loader_mongo_url = loader_args.get("mongo_url", self.mongo_url)
- self.loader_db_name = loader_args.get("db_name", self.db_name)
-
-
-class DataConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for data
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.handler_module_path = config.get("module_path", "qlib.contrib.data.handler")
- self.handler_class = config.get("class", "ALPHA360")
- self.handler_parameters = config.get("args", dict())
- self.handler_filter = config.get("filter", dict())
- # Update provider uri.
-
-
-class ModelConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for model
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.model_class = config.get("class", "Model")
- self.model_module_path = config.get("module_path", "qlib.model")
- self.save_dir = os.path.join(CONFIG_MANAGER.ex_config.tmp_run_dir, "model")
- self.save_path = config.get("save_path", os.path.join(self.save_dir, "model.bin"))
- self.parameters = config.get("args", dict())
- # Make dir if need.
- if not os.path.exists(self.save_dir):
- os.makedirs(self.save_dir)
-
-
-class TrainerConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for trainer
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.trainer_class = config.get("class", "StaticTrainer")
- self.trainer_module_path = config.get("module_path", "qlib.contrib.estimator.trainer")
- self.parameters = config.get("args", dict())
-
-
-class StrategyConfig(object):
- def __init__(self, config, CONFIG_MANAGER):
- """__init__
-
- :param config: The config dict for strategy
- :param CONFIG_MANAGER: The estimator config manager
- """
- self.strategy_class = config.get("class", "TopkDropoutStrategy")
- self.strategy_module_path = config.get("module_path", "qlib.contrib.strategy.strategy")
- self.parameters = config.get("args", dict())
-
-
-class BacktestConfig(object):
- def __init__(self, config, CONFIG_MANAGE):
- """__init__
-
- :param config: The config dict for strategy
- :param CONFIG_MANAGE: The estimator config manager
- """
- self.normal_backtest_parameters = config.get("normal_backtest_args", dict())
- self.long_short_backtest_parameters = config.get("long_short_backtest_args", dict())
-
-
-class QlibDataConfig(object):
- def __init__(self, config, CONFIG_MANAGE):
- """__init__
-
- :param config: The config dict for qlib_client
- :param CONFIG_MANAGE: The estimator config manager
- """
- self.provider_uri = config.pop("provider_uri", "~/.qlib/qlib_data/cn_data")
- self.auto_mount = config.pop("auto_mount", False)
- self.mount_path = config.pop("mount_path", "~/.qlib/qlib_data/cn_data")
- self.region = config.pop("region", REG_CN)
- self.args = config
diff --git a/qlib/contrib/estimator/estimator.py b/qlib/contrib/estimator/estimator.py
deleted file mode 100644
index 56495e5eb..000000000
--- a/qlib/contrib/estimator/estimator.py
+++ /dev/null
@@ -1,328 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# coding=utf-8
-
-import pandas as pd
-
-import os
-import copy
-import json
-import yaml
-import pickle
-
-import qlib
-from ..evaluate import risk_analysis
-from ..evaluate import backtest as normal_backtest
-from ..evaluate import long_short_backtest
-from .config import ExperimentConfig
-from .fetcher import create_fetcher_with_config
-
-from ...log import get_module_logger, TimeInspector
-from ...utils import get_module_by_module_path, compare_dict_value
-
-
-class Estimator(object):
- def __init__(self, config_manager, sacred_ex):
-
- # Set logger.
- self.logger = get_module_logger("Estimator")
-
- # 1. Set config manager.
- self.config_manager = config_manager
-
- # 2. Set configs.
- self.ex_config = config_manager.ex_config
- self.data_config = config_manager.data_config
- self.model_config = config_manager.model_config
- self.trainer_config = config_manager.trainer_config
- self.strategy_config = config_manager.strategy_config
- self.backtest_config = config_manager.backtest_config
-
- # If experiment.mode is test or experiment.finetune is True, load the experimental results in the loader
- if self.ex_config.mode == self.ex_config.TEST_MODE or self.ex_config.finetune:
- self.compare_config_with_config_manger(self.config_manager)
-
- # 3. Set sacred_experiment.
- self.ex = sacred_ex
-
- # 4. Init data handler.
- self.data_handler = None
- self._init_data_handler()
-
- # 5. Init trainer.
- self.trainer = None
- self._init_trainer()
-
- # 6. Init strategy.
- self.strategy = None
- self._init_strategy()
-
- def _init_data_handler(self):
- handler_module = get_module_by_module_path(self.data_config.handler_module_path)
-
- # Set market
- market = self.data_config.handler_filter.get("market", None)
- if market is None:
- if "market" in self.data_config.handler_parameters:
- self.logger.warning(
- "Warning: The market in data.args section is deprecated. "
- "It only works when market is not set in data.filter section. "
- "It will be overridden by market in the data.filter section."
- )
- market = self.data_config.handler_parameters["market"]
- else:
- market = "csi500"
-
- self.data_config.handler_parameters["market"] = market
-
- data_filter_list = []
- handler_filters = self.data_config.handler_filter.get("filter_pipeline", list())
- for h_filter in handler_filters:
- filter_module_path = h_filter.get("module_path", "qlib.data.filter")
- filter_class_name = h_filter.get("class", "")
- filter_parameters = h_filter.get("args", {})
- filter_module = get_module_by_module_path(filter_module_path)
- filter_class = getattr(filter_module, filter_class_name)
- data_filter = filter_class(**filter_parameters)
- data_filter_list.append(data_filter)
-
- self.data_config.handler_parameters["data_filter_list"] = data_filter_list
- handler_class = getattr(handler_module, self.data_config.handler_class)
- self.data_handler = handler_class(**self.data_config.handler_parameters)
-
- def _init_trainer(self):
-
- model_module = get_module_by_module_path(self.model_config.model_module_path)
- trainer_module = get_module_by_module_path(self.trainer_config.trainer_module_path)
- model_class = getattr(model_module, self.model_config.model_class)
- trainer_class = getattr(trainer_module, self.trainer_config.trainer_class)
-
- self.trainer = trainer_class(
- model_class,
- self.model_config.save_path,
- self.model_config.parameters,
- self.data_handler,
- self.ex,
- **self.trainer_config.parameters
- )
-
- def _init_strategy(self):
-
- module = get_module_by_module_path(self.strategy_config.strategy_module_path)
- strategy_class = getattr(module, self.strategy_config.strategy_class)
- self.strategy = strategy_class(**self.strategy_config.parameters)
-
- def run(self):
- if self.ex_config.mode == ExperimentConfig.TRAIN_MODE:
- self.trainer.train()
- elif self.ex_config.mode == ExperimentConfig.TEST_MODE:
- self.trainer.load()
- else:
- raise ValueError("unexpected mode: %s" % self.ex_config.mode)
- analysis = self.backtest()
- print(analysis)
- self.logger.info(
- "experiment id: {}, experiment name: {}".format(self.ex.experiment.current_run._id, self.ex_config.name)
- )
-
- # Remove temp dir
- # shutil.rmtree(self.ex_config.tmp_run_dir)
-
- def backtest(self):
- TimeInspector.set_time_mark()
- # 1. Get pred and prediction score of model(s).
- pred = self.trainer.get_test_score()
- try:
- performance = self.trainer.get_test_performance()
- except NotImplementedError:
- performance = None
- # 2. Normal Backtest.
- report_normal, positions_normal = self._normal_backtest(pred)
- # 3. Long-Short Backtest.
- # Deprecated
- # long_short_reports = self._long_short_backtest(pred)
- # 4. Analyze
- analysis_df = self._analyze(report_normal)
- # 5. Save.
- self._save_backtest_result(
- pred,
- analysis_df,
- positions_normal,
- report_normal,
- # long_short_reports,
- performance,
- )
- return analysis_df
-
- def _normal_backtest(self, pred):
- TimeInspector.set_time_mark()
- if "account" not in self.backtest_config.normal_backtest_parameters:
- if "account" in self.strategy_config.parameters:
- self.logger.warning(
- "Warning: The account in strategy section is deprecated. "
- "It only works when account is not set in backtest section. "
- "It will be overridden by account in the backtest section."
- )
- self.backtest_config.normal_backtest_parameters["account"] = self.strategy_config.parameters["account"]
- report_normal, positions_normal = normal_backtest(
- pred, strategy=self.strategy, **self.backtest_config.normal_backtest_parameters
- )
- TimeInspector.log_cost_time("Finished normal backtest.")
- return report_normal, positions_normal
-
- def _long_short_backtest(self, pred):
- TimeInspector.set_time_mark()
- long_short_reports = long_short_backtest(pred, **self.backtest_config.long_short_backtest_parameters)
- TimeInspector.log_cost_time("Finished long-short backtest.")
- return long_short_reports
-
- @staticmethod
- def _analyze(report_normal):
- TimeInspector.set_time_mark()
-
- analysis = dict()
- # analysis["pred_long"] = risk_analysis(long_short_reports["long"])
- # analysis["pred_short"] = risk_analysis(long_short_reports["short"])
- # analysis["pred_long_short"] = risk_analysis(long_short_reports["long_short"])
- analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"])
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"]
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- TimeInspector.log_cost_time(
- "Finished generating analysis," " average turnover is: {0:.4f}.".format(report_normal["turnover"].mean())
- )
- return analysis_df
-
- def _save_backtest_result(self, pred, analysis, positions, report_normal, performance):
- # 1. Result dir.
- result_dir = os.path.join(self.config_manager.ex_config.tmp_run_dir, "result")
- if not os.path.exists(result_dir):
- os.makedirs(result_dir)
-
- self.ex.add_info(
- "task_config",
- json.loads(json.dumps(self.config_manager.config, default=str)),
- )
-
- # 2. Pred.
- TimeInspector.set_time_mark()
- pred_pkl_path = os.path.join(result_dir, "pred.pkl")
- pred.to_pickle(pred_pkl_path)
- self.ex.add_artifact(pred_pkl_path)
- TimeInspector.log_cost_time("Finished saving pred.pkl to: {}".format(pred_pkl_path))
-
- # 3. Ana.
- TimeInspector.set_time_mark()
- analysis_pkl_path = os.path.join(result_dir, "analysis.pkl")
- analysis.to_pickle(analysis_pkl_path)
- self.ex.add_artifact(analysis_pkl_path)
- TimeInspector.log_cost_time("Finished saving analysis.pkl to: {}".format(analysis_pkl_path))
-
- # 4. Pos.
- TimeInspector.set_time_mark()
- positions_pkl_path = os.path.join(result_dir, "positions.pkl")
- with open(positions_pkl_path, "wb") as fp:
- pickle.dump(positions, fp)
- self.ex.add_artifact(positions_pkl_path)
- TimeInspector.log_cost_time("Finished saving positions.pkl to: {}".format(positions_pkl_path))
-
- # 5. Report normal.
- TimeInspector.set_time_mark()
- report_normal_pkl_path = os.path.join(result_dir, "report_normal.pkl")
- report_normal.to_pickle(report_normal_pkl_path)
- self.ex.add_artifact(report_normal_pkl_path)
- TimeInspector.log_cost_time("Finished saving report_normal.pkl to: {}".format(report_normal_pkl_path))
-
- # 6. Report long short.
- # Deprecated
- # for k, name in zip(
- # ["long", "short", "long_short"],
- # ["report_long.pkl", "report_short.pkl", "report_long_short.pkl"],
- # ):
- # TimeInspector.set_time_mark()
- # pkl_path = os.path.join(result_dir, name)
- # long_short_reports[k].to_pickle(pkl_path)
- # self.ex.add_artifact(pkl_path)
- # TimeInspector.log_cost_time("Finished saving {} to: {}".format(name, pkl_path))
-
- # 7. Origin test label.
- TimeInspector.set_time_mark()
- label_pkl_path = os.path.join(result_dir, "label.pkl")
- self.data_handler.get_origin_test_label_with_date(
- self.trainer_config.parameters["test_start_date"],
- self.trainer_config.parameters["test_end_date"],
- ).to_pickle(label_pkl_path)
- self.ex.add_artifact(label_pkl_path)
- TimeInspector.log_cost_time("Finished saving label.pkl to: {}".format(label_pkl_path))
-
- # 8. Experiment info, save the model(s) performance here.
- TimeInspector.set_time_mark()
- cur_ex_id = self.ex.experiment.current_run._id
- exp_info = {
- "id": cur_ex_id,
- "name": self.ex_config.name,
- "performance": performance,
- "observer_type": self.ex_config.observer_type,
- }
-
- if self.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
- exp_info.update(
- {
- "mongo_url": self.ex_config.mongo_url,
- "db_name": self.ex_config.db_name,
- }
- )
- else:
- exp_info.update({"dir": self.ex_config.global_dir})
-
- with open(self.ex_config.exp_info_path, "w") as fp:
- json.dump(exp_info, fp, indent=4, sort_keys=True)
- self.ex.add_artifact(self.ex_config.exp_info_path)
- TimeInspector.log_cost_time("Finished saving ex_info to: {}".format(self.ex_config.exp_info_path))
-
- @staticmethod
- def compare_config_with_config_manger(config_manager):
- """Compare loader model args and current config with ConfigManage
-
- :param config_manager: ConfigManager
- :return:
- """
- fetcher = create_fetcher_with_config(config_manager, load_form_loader=True)
- loader_mode_config = fetcher.get_experiment(
- exp_name=config_manager.ex_config.loader_name,
- exp_id=config_manager.ex_config.loader_id,
- fields=["task_config"],
- )["task_config"]
- with open(config_manager.config_path) as fp:
- current_config = yaml.load(fp.read())
- current_config = json.loads(json.dumps(current_config, default=str))
-
- logger = get_module_logger("Estimator")
-
- loader_mode_config = copy.deepcopy(loader_mode_config)
- current_config = copy.deepcopy(current_config)
-
- # Require test_mode_config.test_start_date <= current_config.test_start_date
- loader_trainer_args = loader_mode_config.get("trainer", {}).get("args", {})
- cur_trainer_args = current_config.get("trainer", {}).get("args", {})
- loader_start_date = loader_trainer_args.pop("test_start_date")
- cur_test_start_date = cur_trainer_args.pop("test_start_date")
- assert (
- loader_start_date <= cur_test_start_date
- ), "Require: loader_mode_config.test_start_date <= current_config.test_start_date"
-
- # TODO: For the user's own extended `Trainer`, the support is not very good
- if "RollingTrainer" == current_config.get("trainer", {}).get("class", None):
- loader_period = loader_trainer_args.pop("rolling_period")
- cur_period = cur_trainer_args.pop("rolling_period")
- assert (
- loader_period == cur_period
- ), "Require: loader_mode_config.rolling_period == current_config.rolling_period"
-
- compare_section = ["trainer", "model", "data"]
- for section in compare_section:
- changes = compare_dict_value(loader_mode_config.get(section, {}), current_config.get(section, {}))
- if changes:
- logger.warning("Warning: Loader mode config and current config, `{}` are different:\n".format(section))
diff --git a/qlib/contrib/estimator/fetcher.py b/qlib/contrib/estimator/fetcher.py
deleted file mode 100644
index 16ef1dc60..000000000
--- a/qlib/contrib/estimator/fetcher.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# coding=utf-8
-
-import copy
-import json
-import yaml
-import pickle
-import gridfs
-import pymongo
-from pathlib import Path
-from abc import abstractmethod
-
-from .config import EstimatorConfigManager, ExperimentConfig
-
-
-class Fetcher(object):
- """Sacred Experiments Fetcher"""
-
- @abstractmethod
- def _get_experiment(self, exp_name, exp_id):
- """Get experiment basic info with experiment and experiment id
-
- :param exp_name: experiment name
- :param exp_id: experiment id
- :return: dict
- Must contain keys: _id, experiment, info, stop_time.
- Here is an example below for FileFetcher.
- exp = {
- '_id': exp_id, # experiment id
- 'path': path, # experiment result path
- 'experiment': {'name': exp_name}, # experiment
- 'info': info, # experiment config info
- 'stop_time': run.get('stop_time', None) # The time the experiment ended
- }
-
- """
- pass
-
- @abstractmethod
- def _list_experiments(self, exp_name=None):
- """Get experiment basic info list with experiment name
-
- :param exp_name: experiment name
- :return: list
-
- """
- pass
-
- @abstractmethod
- def _iter_artifacts(self, experiment):
- """Get information about the data in the experiment results
-
- :param experiment: `self._get_experiment` method result
- :return: iterable
- Each element contains two elements.
- first element : data name
- second element : data uri
- """
- pass
-
- @abstractmethod
- def _load_data(self, uri):
- """Load data with uri
-
- :param uri: data uri
- :return: bytes
- """
- pass
-
- @staticmethod
- def model_dict_to_buffer_list(model_dict):
- """
-
- :param model_dict:
- :return:
- """
- model_list = []
- is_static_model = False
- if len(model_dict) == 1 and list(model_dict.keys())[0] == "model.bin":
- is_static_model = True
- model_list.append(list(model_dict.values())[0])
- else:
- sep = "model.bin_"
- model_ids = list(map(lambda x: int(x.split(sep)[1]), model_dict.keys()))
- min_id, max_id = min(model_ids), max(model_ids)
- for i in range(min_id, max_id + 1):
- model_key = sep + str(i)
- model = model_dict.get(model_key, None)
- if model is None:
- print(
- "WARNING: In Fetcher, {} is missing when the get model is in the get_experiment function.".format(
- model_key
- )
- )
- break
- else:
- model_list.append(model)
-
- if is_static_model:
- return model_list[0]
-
- return model_list
-
- def get_experiments(self, exp_name=None):
- """Get experiments with name.
-
- :param exp_name: str
- If `exp_name` is set to None, then all experiments will return.
- :return: dict
- Experiments info dict(Including experiment id and task_config to run the
- experiment). Here is an example below.
- {
- 'a_experiment': [
- {
- 'id': '1',
- 'task_config': {...}
- },
- ...
- ]
- ...
- }
- """
- res = dict()
- for ex in self._list_experiments(exp_name):
- name = ex["experiment"]["name"]
- tmp = {
- "id": ex["_id"],
- "task_config": ex["info"].get("task_config", {}),
- "ex_run_stop_time": ex.get("stop_time", None),
- }
- res.setdefault(name, []).append(tmp)
- return res
-
- def get_experiment(self, exp_name, exp_id, fields=None):
- """
-
- :param exp_name:
- :param exp_id:
- :param fields: list
- Experiment result fields, if fields is None, will get all fields.
- Currently supported fields:
- ['model', 'analysis', 'positions', 'report_normal', 'pred', 'task_config', 'label']
- :return: dict
- """
- fields = copy.copy(fields)
- ex = self._get_experiment(exp_name, exp_id)
- results = dict()
- model_dict = dict()
- for name, uri in self._iter_artifacts(ex):
- # When saving, use `sacred.experiment.add_artifact(filename)` , so `name` is os.path.basename(filename)
- prefix = name.split(".")[0]
- if fields and prefix not in fields:
- continue
- data = self._load_data(uri)
- if prefix == "model":
- model_dict[name] = data
- else:
- results[prefix] = pickle.loads(data)
- # Sort model
- if model_dict:
- results["model"] = self.model_dict_to_buffer_list(model_dict)
-
- # Info
- results["task_config"] = ex["info"].get("task_config", {})
- return results
-
- def estimator_config_to_dict(self, exp_name, exp_id):
- """Save configuration to file
-
- :param exp_name:
- :param exp_id:
- :return: config dict
- """
-
- return self.get_experiment(exp_name, exp_id, fields=["task_config"])["task_config"]
-
-
-class FileFetcher(Fetcher):
- """File Fetcher"""
-
- def __init__(self, experiments_dir):
- self.experiments_dir = Path(experiments_dir)
-
- def _get_experiment(self, exp_name, exp_id):
- path = self.experiments_dir / exp_name / "sacred" / str(exp_id)
- info_path = path / "info.json"
- run_path = path / "run.json"
-
- if info_path.exists():
- with info_path.open("r") as f:
- info = json.load(f)
- else:
- info = {}
-
- if run_path.exists():
- with run_path.open("r") as f:
- run = json.load(f)
- else:
- run = {}
-
- exp = {
- "_id": exp_id,
- "path": path,
- "experiment": {"name": exp_name},
- "info": info,
- "stop_time": run.get("stop_time", None),
- }
- return exp
-
- def _list_experiments(self, exp_name=None):
- runs = []
- for path in self.experiments_dir.glob("{}/sacred/[!_]*".format(exp_name or "*")):
- exp_name, exp_id = path.parents[1].name, path.name
- runs.append(self._get_experiment(exp_name, exp_id))
- return runs
-
- def _iter_artifacts(self, experiment):
- if experiment is None:
- return []
-
- for fname in experiment["path"].iterdir():
- if fname.suffix == ".pkl" or ".bin" in fname.suffix:
- name, uri = fname.name, str(fname)
- yield name, uri
-
- def _load_data(self, uri):
- with open(uri, "rb") as f:
- data = f.read()
- return data
-
-
-class MongoFetcher(Fetcher):
- """MongoDB Fetcher"""
-
- def __init__(self, mongo_url, db_name):
- self.mongo_url = mongo_url
- self.db_name = db_name
- self.client = None
- self.db = None
- self.runs = None
- self.fs = None
- self._setup_mongo_client()
-
- def _setup_mongo_client(self):
- self.client = pymongo.MongoClient(self.mongo_url)
- self.db = self.client[self.db_name]
- self.runs = self.db.runs
- self.fs = gridfs.GridFS(self.db)
-
- def _get_experiment(self, exp_name, exp_id):
- return self.runs.find_one({"_id": exp_id})
-
- def _list_experiments(self, exp_name=None):
- if exp_name is None:
- return self.runs.find()
- return self.runs.find({"experiment.name": exp_name})
-
- def _iter_artifacts(self, experiment):
- if experiment is None:
- return []
- for artifact in experiment.get("artifacts", []):
- name, uri = artifact["name"], artifact["file_id"]
- yield name, uri
-
- def _load_data(self, uri):
- data = self.fs.get(uri).read()
- return data
-
-
-def create_fetcher_with_config(config_manager: EstimatorConfigManager, load_form_loader: bool = False):
- """Create fetcher with loader config
-
- :param config_manager:
- :param load_form_loader
- :return:
- """
- flag = ""
- if load_form_loader:
- flag = "loader_"
- if config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_FILE_STORAGE:
- return FileFetcher(eval("config_manager.ex_config.{}_dir".format("loader" if load_form_loader else "global")))
- elif config_manager.ex_config.observer_type == ExperimentConfig.OBSERVER_MONGO:
- return MongoFetcher(
- mongo_url=eval("config_manager.ex_config.{}mongo_url".format(flag)),
- db_name=eval("config_manager.ex_config.{}db_name".format(flag)),
- )
- else:
- return NotImplementedError("Unkown Backend")
diff --git a/qlib/contrib/estimator/launcher.py b/qlib/contrib/estimator/launcher.py
deleted file mode 100644
index 80717a32c..000000000
--- a/qlib/contrib/estimator/launcher.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-
-import argparse
-import importlib
-
-from ... import init
-from .config import EstimatorConfigManager
-from ...log import get_module_logger
-from sacred import Experiment
-from sacred.observers import FileStorageObserver
-from sacred.observers import MongoObserver
-
-args_parser = argparse.ArgumentParser(prog="estimator")
-args_parser.add_argument(
- "-c",
- "--config_path",
- required=True,
- type=str,
- help="json config path indicates where to load config.",
-)
-
-args = args_parser.parse_args()
-
-
-class SacredExperiment(object):
- def __init__(
- self,
- experiment_name,
- experiment_dir,
- observer_type="file_storage",
- mongo_url=None,
- db_name=None,
- ):
- """__init__
-
- :param experiment_name: The name of the experiments.
- :param experiment_dir: The directory to store all the results of the experiments(This is for file_storage).
- :param observer_type: The observer to record the results: the `file_storage` or `mongo`
- :param mongo_url: The mongo url(for mongo observer)
- :param db_name: The mongo url(for mongo observer)
- """
- self.experiment_name = experiment_name
- self.experiment = Experiment(self.experiment_name)
- self.experiment_dir = experiment_dir
- self.experiment.logger = get_module_logger("Sacred")
-
- self.observer_type = observer_type
- self.mongo_db_url = mongo_url
- self.mongo_db_name = db_name
-
- self._setup_experiment()
-
- def _setup_experiment(self):
- if self.observer_type == "file_storage":
- file_storage_observer = FileStorageObserver.create(basedir=self.experiment_dir)
- self.experiment.observers.append(file_storage_observer)
- elif self.observer_type == "mongo":
- mongo_observer = MongoObserver.create(url=self.mongo_db_url, db_name=self.mongo_db_name)
- self.experiment.observers.append(mongo_observer)
- else:
- raise NotImplementedError("Unsupported observer type: {}".format(self.observer_type))
-
- def add_artifact(self, filename):
- self.experiment.add_artifact(filename)
-
- def add_info(self, key, value):
- self.experiment.info[key] = value
-
- def main_wrapper(self, func):
- return self.experiment.main(func)
-
- def config_wrapper(self, func):
- return self.experiment.config(func)
-
-
-CONFIG_MANAGER = EstimatorConfigManager(args.config_path)
-
-ex = SacredExperiment(
- CONFIG_MANAGER.ex_config.name,
- CONFIG_MANAGER.ex_config.sacred_dir,
- observer_type=CONFIG_MANAGER.ex_config.observer_type,
- mongo_url=CONFIG_MANAGER.ex_config.mongo_url,
- db_name=CONFIG_MANAGER.ex_config.db_name,
-)
-
-# qlib init
-init(
- provider_uri=CONFIG_MANAGER.qlib_data_config.provider_uri,
- mount_path=CONFIG_MANAGER.qlib_data_config.mount_path,
- auto_mount=CONFIG_MANAGER.qlib_data_config.auto_mount,
- region=CONFIG_MANAGER.qlib_data_config.region,
- **CONFIG_MANAGER.qlib_data_config.args
-)
-
-
-@ex.main_wrapper
-def _main():
- # 1. Get estimator class.
- estimator_class = getattr(
- importlib.import_module(".estimator", package="qlib.contrib.estimator"),
- "Estimator",
- )
- # 2. Init estimator.
- estimator = estimator_class(CONFIG_MANAGER, ex)
- estimator.run()
-
-
-def run():
- ex.experiment.run()
-
-
-if __name__ == "__main__":
- run()
diff --git a/qlib/contrib/estimator/trainer.py b/qlib/contrib/estimator/trainer.py
deleted file mode 100644
index 84f387d67..000000000
--- a/qlib/contrib/estimator/trainer.py
+++ /dev/null
@@ -1,317 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-
-# coding=utf-8
-
-from abc import abstractmethod
-
-import pandas as pd
-import numpy as np
-from scipy.stats import pearsonr
-
-from ...log import get_module_logger, TimeInspector
-from ...data.dataset.handler import DataHandlerLP
-from .launcher import CONFIG_MANAGER
-from .fetcher import create_fetcher_with_config
-from ...utils import drop_nan_by_y_index, transform_end_date
-
-
-class BaseTrainer(object):
- def __init__(self, model_class, model_save_path, model_args, data_handler: DataHandlerLP, sacred_ex, **kwargs):
- # 1. Model.
- self.model_class = model_class
- self.model_save_path = model_save_path
- self.model_args = model_args
-
- # 2. Data handler.
- self.data_handler = data_handler
-
- # 3. Sacred ex.
- self.ex = sacred_ex
-
- # 4. Logger.
- self.logger = get_module_logger("Trainer")
-
- # 5. Data time
- self.train_start_date = kwargs.get("train_start_date", None)
- self.train_end_date = kwargs.get("train_end_date", None)
- self.validate_start_date = kwargs.get("validate_start_date", None)
- self.validate_end_date = kwargs.get("validate_end_date", None)
- self.test_start_date = kwargs.get("test_start_date", None)
- self.test_end_date = transform_end_date(kwargs.get("test_end_date", None))
-
- @abstractmethod
- def train(self):
- """
- Implement this method indicating how to train a model.
- """
- pass
-
- @abstractmethod
- def load(self):
- """
- Implement this method indicating how to restore a model and the data.
- """
- pass
-
- @abstractmethod
- def get_test_pred(self):
- """
- Implement this method indicating how to get prediction result(s) from a model.
- """
- pass
-
- def get_test_performance(self):
- """
- Implement this method indicating how to get the performance of the model.
- """
- raise NotImplementedError(f"Please implement `get_test_performance`")
-
- def get_test_score(self):
- """
- Override this method to transfer the predict result(s) into the score of the stock.
- Note: If this is a multi-label training, you need to transfer predict labels into one score.
- Or you can just use the result of `get_test_pred()` (you can also process the result) if this is one label training.
- We use the first column of the result of `get_test_pred()` as default method (regard it as one label training).
- """
- pred = self.get_test_pred()
- pred_score = pd.DataFrame(index=pred.index)
- pred_score["score"] = pred.iloc(axis=1)[0]
- return pred_score
-
-
-class StaticTrainer(BaseTrainer):
- def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
- super(StaticTrainer, self).__init__(model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs)
- self.model = None
-
- split_data = self.data_handler.get_split_data(
- self.train_start_date,
- self.train_end_date,
- self.validate_start_date,
- self.validate_end_date,
- self.test_start_date,
- self.test_end_date,
- )
- (
- self.x_train,
- self.y_train,
- self.x_validate,
- self.y_validate,
- self.x_test,
- self.y_test,
- ) = split_data
-
- def train(self):
- TimeInspector.set_time_mark()
- model = self.model_class(**self.model_args)
-
- if CONFIG_MANAGER.ex_config.finetune:
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
-
- if isinstance(loader_model, list):
- model_index = (
- -1
- if CONFIG_MANAGER.ex_config.loader_model_index is None
- else CONFIG_MANAGER.ex_config.loader_model_index
- )
- loader_model = loader_model[model_index]
-
- model.load(loader_model)
- model.finetune(self.x_train, self.y_train, self.x_validate, self.y_validate)
- else:
- model.fit(self.x_train, self.y_train, self.x_validate, self.y_validate)
- model.save(self.model_save_path)
- self.ex.add_artifact(self.model_save_path)
- self.model = model
- TimeInspector.log_cost_time("Finished training model.")
-
- def load(self):
- model = self.model_class(**self.model_args)
-
- # Load model
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
-
- if isinstance(loader_model, list):
- model_index = (
- -1
- if CONFIG_MANAGER.ex_config.loader_model_index is None
- else CONFIG_MANAGER.ex_config.loader_model_index
- )
- loader_model = loader_model[model_index]
-
- model.load(loader_model)
-
- # Save model, after load, if you don't save the model, the result of this experiment will be no model
- model.save(self.model_save_path)
- self.ex.add_artifact(self.model_save_path)
- self.model = model
-
- def get_test_pred(self):
- pred = self.model.predict(self.x_test)
- pred = pd.DataFrame(pred, index=self.x_test.index, columns=self.y_test.columns)
- return pred
-
- def get_test_performance(self):
- try:
- model_score = self.model.score(self.x_test, self.y_test)
- except NotImplementedError:
- model_score = None
- # Remove rows from x, y and w, which contain Nan in any columns in y_test.
- x_test, y_test, __ = drop_nan_by_y_index(self.x_test, self.y_test)
- pred_test = self.model.predict(x_test)
- model_pearsonr = pearsonr(np.ravel(pred_test), np.ravel(y_test.values))[0]
-
- performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
- return performance
-
-
-class RollingTrainer(BaseTrainer):
- def __init__(self, model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs):
- super(RollingTrainer, self).__init__(
- model_class, model_save_path, model_args, data_handler, sacred_ex, **kwargs
- )
- self.rolling_period = kwargs.get("rolling_period", 60)
- self.models = []
- self.rolling_data = []
- self.all_x_test = []
- self.all_y_test = []
- for data in self.data_handler.get_rolling_data(
- self.train_start_date,
- self.train_end_date,
- self.validate_start_date,
- self.validate_end_date,
- self.test_start_date,
- self.test_end_date,
- self.rolling_period,
- ):
- self.rolling_data.append(data)
- __, __, __, __, x_test, y_test = data
- self.all_x_test.append(x_test)
- self.all_y_test.append(y_test)
-
- def train(self):
- # 1. Get total data parts.
- # total_data_parts = self.data_handler.total_data_parts
- # self.logger.warning('Total numbers of model are: {}, start training models...'.format(total_data_parts))
- if CONFIG_MANAGER.ex_config.finetune:
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
- loader_model_index = CONFIG_MANAGER.ex_config.loader_model_index
- previous_model_path = ""
- # 2. Rolling train.
- for (
- index,
- (x_train, y_train, x_validate, y_validate, x_test, y_test),
- ) in enumerate(self.rolling_data):
- TimeInspector.set_time_mark()
- model = self.model_class(**self.model_args)
-
- if CONFIG_MANAGER.ex_config.finetune:
- # Finetune model
- if loader_model_index is None and isinstance(loader_model, list):
- try:
- model.load(loader_model[index])
- except IndexError:
- # Load model by previous_model_path
- with open(previous_model_path, "rb") as fp:
- model.load(fp)
- model.finetune(x_train, y_train, x_validate, y_validate)
- else:
-
- if index == 0:
- loader_model = (
- loader_model[loader_model_index] if isinstance(loader_model, list) else loader_model
- )
- model.load(loader_model)
- else:
- with open(previous_model_path, "rb") as fp:
- model.load(fp)
-
- model.finetune(x_train, y_train, x_validate, y_validate)
-
- else:
- model.fit(x_train, y_train, x_validate, y_validate)
-
- model_save_path = "{}_{}".format(self.model_save_path, index)
- model.save(model_save_path)
- previous_model_path = model_save_path
- self.ex.add_artifact(model_save_path)
- self.models.append(model)
- TimeInspector.log_cost_time("Finished training model: {}.".format(index + 1))
-
- def load(self):
- """
- Load the data and the model
- """
- fetcher = create_fetcher_with_config(CONFIG_MANAGER, load_form_loader=True)
- loader_model = fetcher.get_experiment(
- exp_name=CONFIG_MANAGER.ex_config.loader_name,
- exp_id=CONFIG_MANAGER.ex_config.loader_id,
- fields=["model"],
- )["model"]
- for index in range(len(self.all_x_test)):
- model = self.model_class(**self.model_args)
-
- model.load(loader_model[index])
-
- # Save model
- model_save_path = "{}_{}".format(self.model_save_path, index)
- model.save(model_save_path)
- self.ex.add_artifact(model_save_path)
-
- self.models.append(model)
-
- def get_test_pred(self):
- """
- Predict the score on test data with the models.
- Please ensure the models and data are loaded before call this score.
-
- :return: the predicted scores for the pred
- """
- pred_df_list = []
- y_test_columns = self.all_y_test[0].columns
- # Start iteration.
- for model, x_test in zip(self.models, self.all_x_test):
- pred = model.predict(x_test)
- pred_df = pd.DataFrame(pred, index=x_test.index, columns=y_test_columns)
- pred_df_list.append(pred_df)
- return pd.concat(pred_df_list)
-
- def get_test_performance(self):
- """
- Get the performances of the models
-
- :return: the performances of models
- """
- pred_test_list = []
- y_test_list = []
- scorer = self.models[0]._scorer
- for model, x_test, y_test in zip(self.models, self.all_x_test, self.all_y_test):
- # Remove rows from x, y and w, which contain Nan in any columns in y_test.
- x_test, y_test, __ = drop_nan_by_y_index(x_test, y_test)
- pred_test_list.append(model.predict(x_test))
- y_test_list.append(np.squeeze(y_test.values))
-
- pred_test_array = np.concatenate(pred_test_list, axis=0)
- y_test_array = np.concatenate(y_test_list, axis=0)
-
- model_score = scorer(y_test_array, pred_test_array)
- model_pearsonr = pearsonr(np.ravel(y_test_array), np.ravel(pred_test_array))[0]
-
- performance = {"model_score": model_score, "model_pearsonr": model_pearsonr}
- return performance
diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py
index cf1793c93..4bb5e4372 100644
--- a/qlib/contrib/evaluate.py
+++ b/qlib/contrib/evaluate.py
@@ -26,9 +26,9 @@ def risk_analysis(r, N=252):
Parameters
----------
r : pandas.Series
- daily return series
+ daily return series.
N: int
- scaler for annualizing information_ratio (day: 250, week: 50, month: 12)
+ scaler for annualizing information_ratio (day: 250, week: 50, month: 12).
"""
mean = r.mean()
std = r.std(ddof=1)
@@ -61,7 +61,7 @@ def get_strategy(
----------
strategy : Strategy()
- strategy used in backtest
+ strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
@@ -73,14 +73,14 @@ def get_strategy(
sell_limit = pred_in_a_day.count() * margin
- buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
- sell_limit should be no less than topk
+ buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
+ sell_limit should be no less than topk.
n_drop : int
- number of stocks to be replaced in each trading date
+ number of stocks to be replaced in each trading date.
risk_degree: float
- 0-1, 0.95 for example, use 95% money to trade
+ 0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
- strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
+ strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
Returns
-------
@@ -126,21 +126,21 @@ def get_exchange(
----------
# exchange related arguments
- exchange: Exchange()
+ exchange: Exchange().
subscribe_fields: list
- subscribe fields
+ subscribe fields.
open_cost : float
- open transaction cost
+ open transaction cost.
close_cost : float
- close transaction cost
+ close transaction cost.
min_cost : float
- min transaction cost
+ min transaction cost.
trade_unit : int
- 100 for China A
+ 100 for China A.
deal_price: str
- dealing price type: 'close', 'open', 'vwap'
+ dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
- limit move 0.1 (10%) for example, long and short with same limit
+ limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
@@ -193,20 +193,20 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
- **backtest workflow related or commmon arguments**
pred : pandas.DataFrame
- predict should has index and one `score` column
+ predict should has index and one `score` column.
account : float
- init account value
+ init account value.
shift : int
- whether to shift prediction by one day
+ whether to shift prediction by one day.
benchmark : str
- benchmark code, default is SH000905 CSI 500
+ benchmark code, default is SH000905 CSI 500.
verbose : bool
- whether to print log
+ whether to print log.
- **strategy related arguments**
strategy : Strategy()
- strategy used in backtest
+ strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
@@ -218,33 +218,33 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
sell_limit = pred_in_a_day.count() * margin
- buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit)
- sell_limit should be no less than topk
+ buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
+ sell_limit should be no less than topk.
n_drop : int
- number of stocks to be replaced in each trading date
+ number of stocks to be replaced in each trading date.
risk_degree: float
- 0-1, 0.95 for example, use 95% money to trade
+ 0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
- strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy
+ strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
- **exchange related arguments**
-
+
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
- subscribe fields
+ subscribe fields.
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
- min transaction cost
+ min transaction cost.
trade_unit : int
- 100 for China A
+ 100 for China A.
deal_price: str
- dealing price type: 'close', 'open', 'vwap'
+ dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
- limit move 0.1 (10%) for example, long and short with same limit
+ limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
@@ -291,17 +291,17 @@ def long_short_backtest(
"""
A backtest for long-short strategy
- :param pred: The trading signal produced on day `T`
- :param topk: The short topk securities and long topk securities
- :param deal_price: The price to deal the trading
+ :param pred: The trading signal produced on day `T`.
+ :param topk: The short topk securities and long topk securities.
+ :param deal_price: The price to deal the trading.
:param shift: Whether to shift prediction by one day. The trading day will be T+1 if shift==1.
- :param open_cost: open transaction cost
- :param close_cost: close transaction cost
- :param trade_unit: 100 for China A
- :param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit
- :param min_cost: min transaction cost
- :param subscribe_fields: subscribe fields
- :param extract_codes: bool
+ :param open_cost: open transaction cost.
+ :param close_cost: close transaction cost.
+ :param trade_unit: 100 for China A.
+ :param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit.
+ :param min_cost: min transaction cost.
+ :param subscribe_fields: subscribe fields.
+ :param extract_codes: bool.
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
:return: The result of backtest, it is represented by a dict.
diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py
index e487a6d1e..bba006c35 100644
--- a/qlib/contrib/model/catboost_model.py
+++ b/qlib/contrib/model/catboost_model.py
@@ -1,3 +1,15 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import numpy as np
import pandas as pd
from catboost import Pool, CatBoost
diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py
new file mode 100644
index 000000000..8f5ddc486
--- /dev/null
+++ b/qlib/contrib/model/pytorch_alstm.py
@@ -0,0 +1,349 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from sklearn.metrics import roc_auc_score, mean_squared_error
+import logging
+from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
+from ...log import get_module_logger, TimeInspector
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class ALSTM(Model):
+ """ALSTM Model
+
+ Parameters
+ ----------
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ num_layers=2,
+ dropout=0.0,
+ n_epochs=200,
+ lr=0.001,
+ metric="",
+ batch_size=2000,
+ early_stop=20,
+ loss="mse",
+ optimizer="adam",
+ GPU="0",
+ seed=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("ALSTM")
+ self.logger.info("ALSTM pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.n_epochs = n_epochs
+ self.lr = lr
+ self.metric = metric
+ self.batch_size = batch_size
+ self.early_stop = early_stop
+ self.optimizer = optimizer.lower()
+ self.loss = loss
+ self.visible_GPU = GPU
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+
+ self.logger.info(
+ "ALSTM parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\nnum_layers : {}"
+ "\ndropout : {}"
+ "\nn_epochs : {}"
+ "\nlr : {}"
+ "\nmetric : {}"
+ "\nbatch_size : {}"
+ "\nearly_stop : {}"
+ "\noptimizer : {}"
+ "\nloss_type : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ num_layers,
+ dropout,
+ n_epochs,
+ lr,
+ metric,
+ batch_size,
+ early_stop,
+ optimizer.lower(),
+ loss,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ self.ALSTM_model = ALSTMModel(
+ d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
+ )
+ if optimizer.lower() == "adam":
+ self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
+ elif optimizer.lower() == "gd":
+ self.train_optimizer = optim.SGD(self.ALSTM_model.parameters(), lr=self.lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+ self._fitted = False
+ if self.use_gpu:
+ self.ALSTM_model.cuda()
+ # set the visible GPU
+ if self.visible_GPU:
+ os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
+
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss": # use loss
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
+
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ self.ALSTM_model.train()
+
+ indices = np.arange(len(x_train_values))
+ np.random.shuffle(indices)
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.ALSTM_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.ALSTM_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.ALSTM_model.eval()
+
+ scores = []
+ losses = []
+
+ indices = np.arange(len(x_values))
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float()
+ label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.ALSTM_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid, df_test = dataset.prepare(
+ ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ if save_path == None:
+ save_path = create_save_path(save_path)
+ stop_steps = 0
+ train_loss = 0
+ best_score = -np.inf
+ best_epoch = 0
+ evals_result["train"] = []
+ evals_result["valid"] = []
+
+ # train
+ self.logger.info("training...")
+ self._fitted = True
+
+ for step in range(self.n_epochs):
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
+
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.ALSTM_model.state_dict())
+ else:
+ stop_steps += 1
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+ self.ALSTM_model.load_state_dict(best_param)
+ torch.save(best_param, save_path)
+
+ if self.use_gpu:
+ torch.cuda.empty_cache()
+
+ def predict(self, dataset):
+ if not self._fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.ALSTM_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
+
+ for begin in range(sample_num)[:: self.batch_size]:
+
+ if sample_num - begin < self.batch_size:
+ end = sample_num
+ else:
+ end = begin + self.batch_size
+
+ x_batch = torch.from_numpy(x_values[begin:end]).float()
+
+ if self.use_gpu:
+ x_batch = x_batch.cuda()
+
+ with torch.no_grad():
+ if self.use_gpu:
+ pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
+ else:
+ pred = self.ALSTM_model(x_batch).detach().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class ALSTMModel(nn.Module):
+ def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU"):
+ super().__init__()
+ self.hid_size = hidden_size
+ self.input_size = d_feat
+ self.dropout = dropout
+ self.rnn_type = rnn_type
+ self.rnn_layer = num_layers
+ self._build_model()
+
+ def _build_model(self):
+ try:
+ klass = getattr(nn, self.rnn_type.upper())
+ except:
+ raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
+ self.net = nn.Sequential()
+ self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
+ self.net.add_module("act", nn.Tanh())
+ self.rnn = klass(
+ input_size=self.hid_size,
+ hidden_size=self.hid_size,
+ num_layers=self.rnn_layer,
+ batch_first=True,
+ dropout=self.dropout,
+ )
+ self.fc_out = nn.Linear(in_features=self.hid_size * 2, out_features=1)
+ self.att_net = nn.Sequential()
+ self.att_net.add_module("att_fc_in", nn.Linear(in_features=self.hid_size, out_features=int(self.hid_size / 2)))
+ self.att_net.add_module("att_dropout", torch.nn.Dropout(self.dropout))
+ self.att_net.add_module("att_act", nn.Tanh())
+ self.att_net.add_module("att_fc_out", nn.Linear(in_features=int(self.hid_size / 2), out_features=1, bias=False))
+ self.att_net.add_module("att_softmax", nn.Softmax(dim=1))
+
+ def forward(self, inputs):
+ # inputs: [batch_size, input_size*input_day]
+ inputs = inputs.view(len(inputs), self.input_size, -1)
+ inputs = inputs.permute(0, 2, 1) # [batch, input_size, seq_len] -> [batch, seq_len, input_size]
+ rnn_out, _ = self.rnn(self.net(inputs)) # [batch, seq_len, num_directions * hidden_size]
+ attention_score = self.att_net(rnn_out) # [batch, seq_len, 1]
+ out_att = torch.mul(rnn_out, attention_score)
+ out_att = torch.sum(out_att, dim=1)
+ out = self.fc_out(
+ torch.cat((rnn_out[:, -1, :], out_att), dim=1)
+ ) # [batch, seq_len, num_directions * hidden_size] -> [batch, 1]
+ return out[..., 0]
diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py
index 77e3b9de9..77a02a9b2 100755
--- a/qlib/contrib/model/pytorch_gats.py
+++ b/qlib/contrib/model/pytorch_gats.py
@@ -9,10 +9,8 @@ import os
import numpy as np
import pandas as pd
import copy
-from sklearn.metrics import roc_auc_score, mean_squared_error
-import logging
-from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
-from ...log import get_module_logger, TimeInspector
+from ...utils import create_save_path
+from ...log import get_module_logger
import torch
import torch.nn as nn
@@ -28,14 +26,12 @@ class GAT(Model):
Parameters
----------
- input_dim : int
- input dimension
- output_dim : int
- output dimension
- layers : tuple
- layer sizes
lr : float
learning rate
+ d_feat : int
+ input dimensions for each time step
+ metric : str
+ the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
@@ -50,8 +46,7 @@ class GAT(Model):
dropout=0.0,
n_epochs=200,
lr=0.001,
- metric="IC",
- batch_size=2000,
+ metric="",
early_stop=20,
loss="mse",
base_model="GRU",
@@ -73,7 +68,6 @@ class GAT(Model):
self.n_epochs = n_epochs
self.lr = lr
self.metric = metric
- self.batch_size = batch_size
self.early_stop = early_stop
self.optimizer = optimizer.lower()
self.loss = loss
@@ -92,7 +86,6 @@ class GAT(Model):
"\nn_epochs : {}"
"\nlr : {}"
"\nmetric : {}"
- "\nbatch_size : {}"
"\nearly_stop : {}"
"\noptimizer : {}"
"\nloss_type : {}"
@@ -108,7 +101,6 @@ class GAT(Model):
n_epochs,
lr,
metric,
- batch_size,
early_stop,
optimizer.lower(),
loss,
@@ -120,10 +112,6 @@ class GAT(Model):
)
)
- if loss not in {"mse", "binary"}:
- raise NotImplementedError("loss {} is not supported!".format(loss))
- self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
-
self.GAT_model = GATModel(
d_feat=self.d_feat,
hidden_size=self.hidden_size,
@@ -160,34 +148,37 @@ class GAT(Model):
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
- if self.metric == "IC":
- return self.cal_ic(pred[mask], label[mask])
if self.metric == "" or self.metric == "loss": # use loss
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
- def cal_ic(self, pred, label):
- return torch.mean(pred * label)
+ def get_daily_inter(self, df, shuffle=False):
+ # organize the train data into daily inter as daily batches
+ daily_count = df.groupby(level=0).size().values
+ daily_index = np.roll(np.cumsum(daily_count), 1)
+ daily_index[0] = 0
+ if shuffle:
+ # shuffle the daily inter data
+ daily_shuffle = list(zip(daily_index, daily_count))
+ np.random.shuffle(daily_shuffle)
+ daily_index, daily_count = zip(*daily_shuffle)
+ return daily_index, daily_count
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
- y_train_values = np.squeeze(y_train.values) * 100
-
+ y_train_values = np.squeeze(y_train.values)
self.GAT_model.train()
- indices = np.arange(len(x_train_values))
- np.random.shuffle(indices)
+ # organize the train data into daily inter as daily batches
+ daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
- for i in range(len(indices))[:: self.batch_size]:
-
- if len(indices) - i < self.batch_size:
- break
-
- feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float()
- label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float()
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ feature = torch.from_numpy(x_train_values[batch]).float()
+ label = torch.from_numpy(y_train_values[batch]).float()
if self.use_gpu:
feature = feature.cuda()
@@ -212,16 +203,13 @@ class GAT(Model):
scores = []
losses = []
- indices = np.arange(len(x_values))
- np.random.shuffle(indices)
+ # organize the test data into daily inter as daily batches
+ daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
- for i in range(len(indices))[:: self.batch_size]:
-
- if len(indices) - i < self.batch_size:
- break
-
- feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float()
- label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float()
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ feature = torch.from_numpy(x_values[batch]).float()
+ label = torch.from_numpy(y_values[batch]).float()
if self.use_gpu:
feature = feature.cuda()
@@ -254,7 +242,6 @@ class GAT(Model):
if save_path == None:
save_path = create_save_path(save_path)
stop_steps = 0
- train_loss = 0
best_score = -np.inf
best_epoch = 0
evals_result["train"] = []
@@ -265,12 +252,14 @@ class GAT(Model):
self.logger.info("Loading pretrained model...")
if self.base_model == "LSTM":
from ...contrib.model.pytorch_lstm import LSTMModel
+
pretrained_model = LSTMModel()
- pretrained_model.load_state_dict(torch.load('benchmarks/LSTM/model_lstm_csi300.pkl'))
+ pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
elif self.base_model == "GRU":
from ...contrib.model.pytorch_gru import GRUModel
+
pretrained_model = GRUModel()
- pretrained_model.load_state_dict(torch.load('benchmarks/GRU/model_gru_csi300.pkl'))
+ pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
model_dict.update(pretrained_dict)
@@ -319,17 +308,14 @@ class GAT(Model):
index = x_test.index
self.GAT_model.eval()
x_values = x_test.values
- sample_num = x_values.shape[0]
preds = []
- for begin in range(sample_num)[:: self.batch_size]:
+ # organize the data into daily inter as daily batches
+ daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
- if sample_num - begin < self.batch_size:
- end = sample_num
- else:
- end = begin + self.batch_size
-
- x_batch = torch.from_numpy(x_values[begin:end]).float()
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ x_batch = torch.from_numpy(x_values[batch]).float()
if self.use_gpu:
x_batch = x_batch.cuda()
@@ -375,7 +361,6 @@ class GATModel(nn.Module):
self.fc_out = nn.Linear(hidden_size, 1)
self.leaky_relu = nn.LeakyReLU()
self.softmax = nn.Softmax(dim=1)
-
self.d_feat = d_feat
def cal_convariance(self, x, y): # the 2nd dimension of x and y are the same
@@ -394,12 +379,7 @@ class GATModel(nn.Module):
out, _ = self.rnn(x)
hidden = out[:, -1, :]
hidden = self.bn1(hidden)
-
gamma = self.cal_convariance(hidden, hidden)
- # gamma = hidden.mm(torch.t(hidden))
- # gamma = self.leaky_relu(gamma)
- # gamma = self.softmax(gamma)
- # gamma = gamma * (torch.ones(x.shape[0], x.shape[0]).to(device) - torch.diag(torch.ones(x.shape[0])).to(device))
output = gamma.mm(hidden)
output = self.fc(output)
output = self.bn2(output)
diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py
index 4cc7f9852..02664b6ac 100755
--- a/qlib/contrib/model/pytorch_gru.py
+++ b/qlib/contrib/model/pytorch_gru.py
@@ -28,14 +28,10 @@ class GRU(Model):
Parameters
----------
- input_dim : int
- input dimension
- output_dim : int
- output dimension
- layers : tuple
- layer sizes
- lr : float
- learning rate
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
@@ -50,7 +46,7 @@ class GRU(Model):
dropout=0.0,
n_epochs=200,
lr=0.001,
- metric="IC",
+ metric="",
batch_size=2000,
early_stop=20,
loss="mse",
@@ -112,10 +108,6 @@ class GRU(Model):
)
)
- if loss not in {"mse", "binary"}:
- raise NotImplementedError("loss {} is not supported!".format(loss))
- self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
-
self.gru_model = GRUModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
)
@@ -148,21 +140,16 @@ class GRU(Model):
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
- if self.metric == "IC":
- return self.cal_ic(pred[mask], label[mask])
if self.metric == "" or self.metric == "loss": # use loss
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
- def cal_ic(self, pred, label):
- return torch.mean(pred * label)
-
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
- y_train_values = np.squeeze(y_train.values) * 100
+ y_train_values = np.squeeze(y_train.values)
self.gru_model.train()
@@ -201,7 +188,6 @@ class GRU(Model):
losses = []
indices = np.arange(len(x_values))
- np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]:
@@ -251,7 +237,6 @@ class GRU(Model):
# train
self.logger.info("training...")
self._fitted = True
- # return
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
diff --git a/qlib/contrib/model/pytorch_hats.py b/qlib/contrib/model/pytorch_hats.py
new file mode 100644
index 000000000..7affea73c
--- /dev/null
+++ b/qlib/contrib/model/pytorch_hats.py
@@ -0,0 +1,491 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from ...utils import create_save_path
+from ...log import get_module_logger
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class HATS(Model):
+ """HATS Model
+
+ Parameters
+ ----------
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ num_layers=2,
+ dropout=0.5,
+ n_epochs=200,
+ lr=0.01,
+ metric="",
+ early_stop=20,
+ loss="mse",
+ base_model="GRU",
+ with_pretrain=True,
+ optimizer="adam",
+ GPU="0",
+ seed=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("HATS")
+ self.logger.info("HATS pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.n_epochs = n_epochs
+ self.lr = lr
+ self.metric = metric
+ self.early_stop = early_stop
+ self.optimizer = optimizer.lower()
+ self.loss = loss
+ self.base_model = base_model
+ self.with_pretrain = with_pretrain
+ self.visible_GPU = GPU
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+
+ self.logger.info(
+ "HATS parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\nnum_layers : {}"
+ "\ndropout : {}"
+ "\nn_epochs : {}"
+ "\nlr : {}"
+ "\nmetric : {}"
+ "\nearly_stop : {}"
+ "\noptimizer : {}"
+ "\nloss_type : {}"
+ "\nbase_model : {}"
+ "\nwith_pretrain : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ num_layers,
+ dropout,
+ n_epochs,
+ lr,
+ metric,
+ early_stop,
+ optimizer.lower(),
+ loss,
+ base_model,
+ with_pretrain,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ self.HATS_model = HATSModel(
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
+ base_model=self.base_model,
+ )
+ if optimizer.lower() == "adam":
+ self.train_optimizer = optim.Adam(self.HATS_model.parameters(), lr=self.lr)
+ elif optimizer.lower() == "gd":
+ self.train_optimizer = optim.SGD(self.HATS_model.parameters(), lr=self.lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
+
+ self._fitted = False
+ if self.use_gpu:
+ self.HATS_model.cuda()
+ # set the visible GPU
+ if self.visible_GPU:
+ os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_GPU
+
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss": # use loss
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
+
+ def get_daily_inter(self, df, shuffle=False):
+ # organize the train data into daily inter as daily batches
+ daily_count = df.groupby(level=0).size().values
+ daily_index = np.roll(np.cumsum(daily_count), 1)
+ daily_index[0] = 0
+ if shuffle:
+ # shuffle the daily inter data
+ daily_shuffle = list(zip(daily_index, daily_count))
+ np.random.shuffle(daily_shuffle)
+ daily_index, daily_count = zip(*daily_shuffle)
+ return daily_index, daily_count
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ self.HATS_model.train()
+
+ # organize the train data into daily inter as daily batches
+ daily_index, daily_count = self.get_daily_inter(x_train, shuffle=True)
+
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ feature = torch.from_numpy(x_train_values[batch]).float()
+ label = torch.from_numpy(y_train_values[batch]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.HATS_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.HATS_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare testing data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.HATS_model.eval()
+
+ scores = []
+ losses = []
+
+ # organize the test data into daily inter as daily batches
+ daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
+
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ feature = torch.from_numpy(x_values[batch]).float()
+ label = torch.from_numpy(y_values[batch]).float()
+
+ if self.use_gpu:
+ feature = feature.cuda()
+ label = label.cuda()
+
+ pred = self.HATS_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid, df_test = dataset.prepare(
+ ["train", "valid", "test"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
+ )
+
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+
+ if save_path == None:
+ save_path = create_save_path(save_path)
+ stop_steps = 0
+ best_score = -np.inf
+ best_epoch = 0
+ evals_result["train"] = []
+ evals_result["valid"] = []
+
+ # load pretrained base_model
+ if self.with_pretrain:
+ self.logger.info("Loading pretrained model...")
+ if self.base_model == "LSTM":
+ from ...contrib.model.pytorch_lstm import LSTMModel
+
+ pretrained_model = LSTMModel()
+ pretrained_model.load_state_dict(torch.load("benchmarks/LSTM/model_lstm_csi300.pkl"))
+ elif self.base_model == "GRU":
+ from ...contrib.model.pytorch_gru import GRUModel
+
+ pretrained_model = GRUModel()
+ pretrained_model.load_state_dict(torch.load("benchmarks/GRU/model_gru_csi300.pkl"))
+ model_dict = self.HATS_model.state_dict()
+ pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
+ model_dict.update(pretrained_dict)
+ self.HATS_model.load_state_dict(model_dict)
+ self.logger.info("Loading pretrained model Done...")
+
+ # train
+ self.logger.info("training...")
+ self._fitted = True
+
+ for step in range(self.n_epochs):
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
+
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.HATS_model.state_dict())
+ else:
+ stop_steps += 1
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+ self.HATS_model.load_state_dict(best_param)
+ torch.save(best_param, save_path)
+
+ if self.use_gpu:
+ torch.cuda.empty_cache()
+
+ def predict(self, dataset):
+ if not self._fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.HATS_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
+
+ # organize the data into daily inter as daily batches
+ daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
+
+ for idx, count in zip(daily_index, daily_count):
+ batch = slice(idx, idx + count)
+ x_batch = torch.from_numpy(x_values[batch]).float()
+
+ if self.use_gpu:
+ x_batch = x_batch.cuda()
+
+ with torch.no_grad():
+ if self.use_gpu:
+ pred = self.HATS_model(x_batch).detach().cpu().numpy()
+ else:
+ pred = self.HATS_model(x_batch).detach().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class HATSModel(nn.Module):
+ def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"):
+ super().__init__()
+
+ if base_model == "GRU":
+ self.model = nn.GRU(
+ input_size=d_feat,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout,
+ )
+ elif base_model == "LSTM":
+ self.model = nn.LSTM(
+ input_size=d_feat,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout,
+ )
+ else:
+ raise ValueError("unknown base model name `%s`" % base_model)
+
+ self.hidden_size = hidden_size
+ self.bn1 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
+ self.fc = nn.Linear(hidden_size, hidden_size)
+ self.bn2 = nn.BatchNorm1d(num_features=hidden_size, track_running_stats=False)
+ self.fc_out = nn.Linear(hidden_size, 1)
+ self.leaky_relu = nn.LeakyReLU()
+ self.softmax = nn.Softmax(dim=1)
+ self.d_feat = d_feat
+
+ num_head_att = [1] * num_layers
+ hidden_dim = [hidden_size] * num_layers
+ dims = [d_feat] + [d * nh for (d, nh) in zip(hidden_dim, num_head_att[:-1])] + [num_head_att[-1]]
+ in_dims = dims[:-1]
+ out_dims = [d // nh for (d, nh) in zip(dims[1:], num_head_att)]
+ self.attn = nn.ModuleList(
+ [GraphAttention(i, o, nh, dropout) for (i, o, nh) in zip(in_dims, out_dims, num_head_att)]
+ )
+ self.bns = nn.ModuleList([nn.BatchNorm1d(dim) for dim in dims[1:-1]])
+ self.dropout = nn.Dropout(dropout)
+ self.elu = nn.ELU()
+
+ def forward(self, x):
+ x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
+ x = x.permute(0, 2, 1) # [N, T, F]
+ out, _ = self.model(x)
+ hidden = out[:, -1, :]
+ hidden = self.bn1(hidden)
+ attention = GraphAttention.cal_attention(hidden, hidden)
+ output = attention.mm(hidden)
+ output = self.fc(output)
+ output = self.bn2(output)
+ output = self.leaky_relu(output)
+ return self.fc_out(output).squeeze()
+
+
+class GraphAttention(nn.Module):
+ def __init__(self, input_dim, output_dim, num_heads, dropout=0.5):
+
+ super().__init__()
+
+ """
+ Parameters
+ ----------
+ input_dim : int
+ Dimension of input node features.
+ output_dim : int
+ Dimension of output node features.
+ num_heads : list of ints
+ Number of attention heads in each hidden layer and output layer. Must be non empty. Note that len(num_heads) = len(hidden_dims)+1.
+ dropout : float
+ Dropout rate. Default: 0.5.
+ """
+
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.num_heads = num_heads
+
+ self.fcs = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_heads)])
+ self.a = nn.ModuleList([nn.Linear(2 * output_dim, 1) for _ in range(num_heads)])
+
+ self.dropout = nn.Dropout(dropout)
+ self.softmax = nn.Softmax(dim=0)
+ self.leakyrelu = nn.LeakyReLU()
+
+ def forward(self, features, nodes, mappings, rows):
+
+ """
+ Parameters
+ ----------
+ features : torch.Tensor
+ An (n' x input_dim) tensor of input node features.
+ nodes : list of numpy array
+ nodes[i] is an array of the nodes in the ith layer of the
+ computation graph.
+ mappings : list of dictionary
+ mappings[i] is a dictionary mappings node v (labelled 0 to |V|-1)
+ in nodes[i] to its position in nodes[i]. For example,
+ if nodes[i] = [2,5], then mappings[i][2] = 0 and
+ mappings[i][5] = 1.
+ rows : numpy array
+ rows[i] is an array of neighbors of node i.
+ Returns
+ -------
+ out : torch.Tensor
+ An (len(node_layers[-1]) x output_dim) tensor of output node features.
+ """
+
+ nprime = features.shape[0]
+ rows = [np.array([mappings[v] for v in row], dtype=np.int64) for row in rows]
+ sum_degs = np.hstack(([0], np.cumsum([len(row) for row in rows])))
+ mapped_nodes = [mappings[v] for v in nodes]
+ indices = torch.LongTensor([[v, c] for (v, row) in zip(mapped_nodes, rows) for c in row]).t()
+
+ out = []
+ for k in range(self.num_heads):
+ h = self.fcs[k](features)
+
+ nbr_h = torch.cat(tuple([h[row] for row in rows]), dim=0)
+ self_h = torch.cat(
+ tuple([h[mappings[nodes[i]]].repeat(len(row), 1) for (i, row) in enumerate(rows)]), dim=0
+ )
+ cat_h = torch.cat((self_h, nbr_h), dim=1)
+
+ e = self.leakyrelu(self.a[k](cat_h))
+
+ alpha = [self.softmax(e[lo:hi]) for (lo, hi) in zip(sum_degs, sum_degs[1:])]
+ alpha = torch.cat(tuple(alpha), dim=0)
+ alpha = alpha.squeeze(1)
+ alpha = self.dropout(alpha)
+
+ adj = torch.sparse.FloatTensor(indices, alpha, torch.Size([nprime, nprime]))
+ out.append(torch.sparse.mm(adj, h)[mapped_nodes])
+
+ return out
+
+ @staticmethod
+ def cal_attention(x, y):
+ att_x = torch.mean(x, dim=1).reshape(-1, 1)
+ att_y = torch.mean(y, dim=1).reshape(-1, 1)
+ att = att_x.mm(torch.t(att_y))
+ return (
+ torch.mean(
+ x.reshape(x.shape[0], 1, x.shape[1]).repeat(1, y.shape[0], 1)
+ * y.reshape(1, y.shape[0], y.shape[1]).repeat(x.shape[0], 1, 1),
+ dim=2,
+ )
+ - att
+ )
diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py
index 8b8454380..f8951509a 100755
--- a/qlib/contrib/model/pytorch_lstm.py
+++ b/qlib/contrib/model/pytorch_lstm.py
@@ -28,14 +28,10 @@ class LSTM(Model):
Parameters
----------
- input_dim : int
- input dimension
- output_dim : int
- output dimension
- layers : tuple
- layer sizes
- lr : float
- learning rate
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
optimizer : str
optimizer name
GPU : str
@@ -50,7 +46,7 @@ class LSTM(Model):
dropout=0.0,
n_epochs=200,
lr=0.001,
- metric="IC",
+ metric="",
batch_size=2000,
early_stop=20,
loss="mse",
@@ -112,10 +108,6 @@ class LSTM(Model):
)
)
- if loss not in {"mse", "binary"}:
- raise NotImplementedError("loss {} is not supported!".format(loss))
- self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
-
self.lstm_model = LSTMModel(
d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout
)
@@ -148,21 +140,16 @@ class LSTM(Model):
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
- if self.metric == "IC":
- return self.cal_ic(pred[mask], label[mask])
if self.metric == "" or self.metric == "loss": # use loss
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
- def cal_ic(self, pred, label):
- return torch.mean(pred * label)
-
def train_epoch(self, x_train, y_train):
x_train_values = x_train.values
- y_train_values = np.squeeze(y_train.values) * 100
+ y_train_values = np.squeeze(y_train.values)
self.lstm_model.train()
@@ -201,7 +188,6 @@ class LSTM(Model):
losses = []
indices = np.arange(len(x_values))
- np.random.shuffle(indices)
for i in range(len(indices))[:: self.batch_size]:
@@ -251,7 +237,6 @@ class LSTM(Model):
# train
self.logger.info("training...")
self._fitted = True
- # return
for step in range(self.n_epochs):
self.logger.info("Epoch%d:", step)
diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py
index 8564c491c..1d27f3927 100644
--- a/qlib/contrib/model/pytorch_sfm.py
+++ b/qlib/contrib/model/pytorch_sfm.py
@@ -1,5 +1,15 @@
# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from __future__ import division
from __future__ import print_function
@@ -90,10 +100,7 @@ class SFM_Model(nn.Module):
x_c = torch.matmul(x * B_W[0], self.W_c) + self.b_c
x_o = torch.matmul(x * B_W[0], self.W_o) + self.b_o
- i = self.inner_activation(
- x_i + torch.matmul(h_tm1 * B_U[0], self.U_i)
- ) # not sure whether I am doing in the right unsquuze
-
+ i = self.inner_activation(x_i + torch.matmul(h_tm1 * B_U[0], self.U_i))
ste = self.inner_activation(x_ste + torch.matmul(h_tm1 * B_U[0], self.U_ste))
fre = self.inner_activation(x_fre + torch.matmul(h_tm1 * B_U[0], self.U_fre))
@@ -173,10 +180,6 @@ class SFM(Model):
output dimension
lr : float
learning rate
- lr_decay : float
- learning rate decay
- lr_decay_steps : int
- learning rate decay steps
optimizer : str
optimizer name
GPU : str
@@ -193,12 +196,11 @@ class SFM(Model):
dropout_U=0.0,
n_epochs=200,
lr=0.001,
+ metric="",
batch_size=2000,
early_stop=20,
eval_steps=5,
loss="mse",
- lr_decay=0.96,
- lr_decay_steps=100,
optimizer="gd",
GPU="0",
seed=0,
@@ -217,13 +219,12 @@ class SFM(Model):
self.dropout_U = dropout_U
self.n_epochs = n_epochs
self.lr = lr
+ self.metric = metric
self.batch_size = batch_size
self.early_stop = early_stop
self.eval_steps = eval_steps
- self.lr_decay = lr_decay
- self.lr_decay_steps = lr_decay_steps
self.optimizer = optimizer.lower()
- self.loss_type = loss
+ self.loss = loss
self.device = "cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu"
self.use_gpu = torch.cuda.is_available()
self.seed = seed
@@ -232,16 +233,16 @@ class SFM(Model):
"SFM parameters setting:"
"\nd_feat : {}"
"\nhidden_size : {}"
+ "\noutput_size : {}"
"\nfrequency_dimension : {}"
"\ndropout_W: {}"
"\ndropout_U: {}"
"\nn_epochs : {}"
"\nlr : {}"
+ "\nmetric : {}"
"\nbatch_size : {}"
"\nearly_stop : {}"
"\neval_steps : {}"
- "\nlr_decay : {}"
- "\nlr_decay_steps : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\nvisible_GPU : {}"
@@ -249,16 +250,16 @@ class SFM(Model):
"\nseed : {}".format(
d_feat,
hidden_size,
+ output_dim,
freq_dim,
dropout_W,
dropout_U,
n_epochs,
lr,
+ metric,
batch_size,
early_stop,
eval_steps,
- lr_decay,
- lr_decay_steps,
optimizer.lower(),
loss,
GPU,
@@ -267,10 +268,6 @@ class SFM(Model):
)
)
- if loss not in {"mse", "binary"}:
- raise NotImplementedError("loss {} is not supported!".format(loss))
- self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
-
self.sfm_model = SFM_Model(
d_feat=self.d_feat,
output_dim=self.output_dim,
@@ -287,24 +284,72 @@ class SFM(Model):
else:
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
- # Reduce learning rate when loss has stopped decrease
- self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
- self.train_optimizer,
- mode="min",
- factor=0.5,
- patience=10,
- verbose=True,
- threshold=0.0001,
- threshold_mode="rel",
- cooldown=0,
- min_lr=0.00001,
- eps=1e-08,
- )
-
self._fitted = False
self.sfm_model.to(self.device)
- def fit(self, dataset: DatasetH, evals_result=dict(), verbose=True, save_path=None, **kwargs):
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.sfm_model.eval()
+
+ scores = []
+ losses = []
+
+ indices = np.arange(len(x_values))
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
+ label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+ pred = self.sfm_model(feature)
+ loss = self.loss_fn(pred, label)
+ losses.append(loss.item())
+
+ score = self.metric_fn(pred, label)
+ scores.append(score.item())
+
+ return np.mean(losses), np.mean(scores)
+
+ def train_epoch(self, x_train, y_train):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ self.sfm_model.train()
+
+ indices = np.arange(len(x_train_values))
+ np.random.shuffle(indices)
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+ label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+ pred = self.sfm_model(feature)
+ loss = self.loss_fn(pred, label)
+
+ self.train_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.sfm_model.parameters(), 3.0)
+ self.train_optimizer.step()
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
@@ -312,10 +357,10 @@ class SFM(Model):
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_valid["feature"], df_valid["label"]
- save_path = create_save_path(save_path)
stop_steps = 0
train_loss = 0
- best_loss = np.inf
+ best_score = -np.inf
+ best_epoch = 0
evals_result["train"] = []
evals_result["valid"] = []
@@ -323,90 +368,51 @@ class SFM(Model):
self.logger.info("training...")
self._fitted = True
- # prepare training data
- x_train_values = torch.from_numpy(x_train.values).float()
- y_train_values = torch.from_numpy(np.squeeze(y_train.values)).float()
- train_num = y_train_values.shape[0]
-
- # prepare validation data
- x_val_auto = torch.from_numpy(x_valid.values).float()
- y_val_auto = torch.from_numpy(np.squeeze(y_valid.values)).float()
-
- x_val_auto = x_val_auto.to(self.device)
- y_val_auto = y_val_auto.to(self.device)
-
for step in range(self.n_epochs):
- if stop_steps >= self.early_stop:
- if verbose:
- self.logger.info("\tearly stop")
- break
- loss = AverageMeter()
- self.sfm_model.train()
- self.train_optimizer.zero_grad()
+ self.logger.info("Epoch%d:", step)
+ self.logger.info("training...")
+ self.train_epoch(x_train, y_train)
+ self.logger.info("evaluating...")
+ train_loss, train_score = self.test_epoch(x_train, y_train)
+ val_loss, val_score = self.test_epoch(x_valid, y_valid)
+ self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
+ evals_result["train"].append(train_score)
+ evals_result["valid"].append(val_score)
- choice = np.random.choice(train_num, self.batch_size)
- x_batch_auto = x_train_values[choice]
- y_batch_auto = y_train_values[choice]
-
- x_batch_auto = x_batch_auto.to(self.device)
- y_batch_auto = y_batch_auto.to(self.device)
-
- # forward
- preds = self.sfm_model(x_batch_auto)
- cur_loss = self.get_loss(preds, y_batch_auto, self.loss_type)
- cur_loss.backward()
- self.train_optimizer.step()
- loss.update(cur_loss.item())
-
- # validation
- train_loss += loss.val
- # print(loss.val)
- if step and step % self.eval_steps == 0:
+ if val_score > best_score:
+ best_score = val_score
+ stop_steps = 0
+ best_epoch = step
+ best_param = copy.deepcopy(self.sfm_model.state_dict())
+ else:
stop_steps += 1
- train_loss /= self.eval_steps
-
- with torch.no_grad():
- self.sfm_model.eval()
- loss_val = AverageMeter()
-
- # forward
- preds = self.sfm_model(x_val_auto)
- cur_loss_val = self.get_loss(preds, y_val_auto, self.loss_type)
- loss_val.update(cur_loss_val.item())
-
- if verbose:
- self.logger.info(
- "[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
- )
- evals_result["train"].append(train_loss)
- evals_result["valid"].append(loss_val.val)
- if loss_val.val < best_loss:
- if verbose:
- self.logger.info(
- "\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
- best_loss, loss_val.val
- )
- )
- best_loss = loss_val.val
- stop_steps = 0
- torch.save(self.sfm_model.state_dict(), save_path)
- train_loss = 0
- # update learning rate
- self.scheduler.step(cur_loss_val)
-
+ if stop_steps >= self.early_stop:
+ self.logger.info("early stop")
+ break
+ self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
if self.device != "cpu":
torch.cuda.empty_cache()
- def get_loss(self, pred, target, loss_type):
- if loss_type == "mse":
- sqr_loss = (pred - target) ** 2
- loss = sqr_loss.mean()
- return loss
- elif loss_type == "binary":
- loss = nn.BCELoss()
- return loss(pred, target)
- else:
- raise NotImplementedError("loss {} is not supported!".format(loss_type))
+ def mse(self, pred, label):
+ loss = (pred - label) ** 2
+ return torch.mean(loss)
+
+ def loss_fn(self, pred, label):
+ mask = ~torch.isnan(label)
+
+ if self.loss == "mse":
+ return self.mse(pred[mask], label[mask])
+
+ raise ValueError("unknown loss `%s`" % self.loss)
+
+ def metric_fn(self, pred, label):
+
+ mask = torch.isfinite(label)
+
+ if self.metric == "" or self.metric == "loss": # use loss
+ return -self.loss_fn(pred[mask], label[mask])
+
+ raise ValueError("unknown metric `%s`" % self.metric)
def predict(self, dataset):
if not self._fitted:
@@ -414,34 +420,28 @@ class SFM(Model):
x_test = dataset.prepare("test", col_set="feature")
index = x_test.index
- x_test = torch.from_numpy(x_test.values).float()
-
- x_test = x_test.to(self.device)
self.sfm_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
- with torch.no_grad():
- if self.device != "cpu":
- preds = self.sfm_model(x_test).detach().cpu().numpy()
+ for begin in range(sample_num)[:: self.batch_size]:
+ if sample_num - begin < self.batch_size:
+ end = sample_num
else:
- preds = self.sfm_model(x_test).detach().numpy()
- return pd.Series(preds, index=index)
+ end = begin + self.batch_size
- def save(self, filename, **kwargs):
- with save_multiple_parts_file(filename) as model_dir:
- model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])
- # Save model
- torch.save(self.sfm_model.state_dict(), model_path)
+ x_batch = torch.from_numpy(x_values[begin:end]).float()
- def load(self, buffer, **kwargs):
- with unpack_archive_with_buffer(buffer) as model_dir:
- # Get model name
- _model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[
- 0
- ]
- _model_path = os.path.join(model_dir, _model_name)
- # Load model
- self.sfm_model.load_state_dict(torch.load(_model_path))
- self._fitted = True
+ if self.device != "cpu":
+ x_batch = x_batch.to(self.device)
+
+ with torch.no_grad():
+ pred = self.sfm_model(x_batch).detach().cpu().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
class AverageMeter(object):
diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py
index e0691ba16..039fd2c80 100755
--- a/qlib/contrib/model/xgboost.py
+++ b/qlib/contrib/model/xgboost.py
@@ -1,5 +1,14 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import numpy as np
import pandas as pd
@@ -13,10 +22,8 @@ from ...data.dataset.handler import DataHandlerLP
class XGBModel(Model):
"""XGBModel Model"""
- def __init__(self, obj="mse", **kwargs):
- if obj not in {"mse", "binary"}:
- raise NotImplementedError
- self._params = {"obj": obj}
+ def __init__(self, **kwargs):
+ self._params = {}
self._params.update(kwargs)
self.model = None
diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py
index 1c69145db..1cb14d261 100644
--- a/qlib/contrib/report/analysis_model/analysis_model_performance.py
+++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py
@@ -252,7 +252,7 @@ def model_performance_graph(
"""Model performance
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score,
- label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1")
+ label]**. It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1").
.. code-block:: python
@@ -266,13 +266,13 @@ def model_performance_graph(
:param lag: `pred.groupby(level='instrument')['score'].shift(lag)`. It will be only used in the auto-correlation computing.
- :param N: group number, default 5
- :param reverse: if `True`, `pred['score'] *= -1`
- :param rank: if **True**, calculate rank ic
- :param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover']
- :param show_notebook: whether to display graphics in notebook, the default is `True`
- :param show_nature_day: whether to display the abscissa of non-trading day
- :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list
+ :param N: group number, default 5.
+ :param reverse: if `True`, `pred['score'] *= -1`.
+ :param rank: if **True**, calculate rank ic.
+ :param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'].
+ :param show_notebook: whether to display graphics in notebook, the default is `True`.
+ :param show_nature_day: whether to display the abscissa of non-trading day.
+ :return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list.
"""
figure_list = []
for graph_name in graph_names:
diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py
index 941785e83..abb68ea60 100644
--- a/qlib/contrib/report/analysis_position/cumulative_return.py
+++ b/qlib/contrib/report/analysis_position/cumulative_return.py
@@ -218,10 +218,10 @@ def cumulative_return_graph(
Graph desc:
- - Axis X: Trading day
+ - Axis X: Trading day.
- Axis Y:
- - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`
- - Below axis Y: Daily weight sum
+ - Above axis Y: `(((Ref($close, -1)/$close - 1) * weight).sum() / weight.sum()).cumsum()`.
+ - Below axis Y: Daily weight sum.
- In the **sell** graph, `y < 0` stands for profit; in other cases, `y > 0` stands for profit.
- In the **buy_minus_sell** graph, the **y** value of the **weight** graph at the bottom is `buy_weight + sell_weight`.
- In each graph, the **red line** in the histogram on the right represents the average.
diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py
index e2f7fe1cf..72a358adc 100644
--- a/qlib/contrib/report/analysis_position/rank_label.py
+++ b/qlib/contrib/report/analysis_position/rank_label.py
@@ -97,9 +97,9 @@ def rank_label_graph(
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
- :param position: position data; **qlib.contrib.backtest.backtest.backtest** result
+ :param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
:param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**.
- **The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`
+ **The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`.
.. code-block:: python
@@ -115,7 +115,7 @@ def rank_label_graph(
:param start_date: start date
:param end_date: end_date
- :param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures
+ :param show_notebook: **True** or **False**. If True, show graph in notebook, else return figures.
:return:
"""
position = copy.deepcopy(position)
diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py
index 6d108cabf..438aab8b9 100644
--- a/qlib/contrib/report/analysis_position/report.py
+++ b/qlib/contrib/report/analysis_position/report.py
@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
qcr.report_graph(report_normal_df)
- :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**
+ :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
.. code-block:: python
@@ -200,8 +200,8 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
- :param show_notebook: whether to display graphics in notebook, the default is **True**
- :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
+ :param show_notebook: whether to display graphics in notebook, the default is **True**.
+ :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
"""
report_df = report_df.copy()
fig_list = _report_figure(report_df)
diff --git a/qlib/contrib/report/analysis_position/risk_analysis.py b/qlib/contrib/report/analysis_position/risk_analysis.py
index 124a9b3b0..051c78035 100644
--- a/qlib/contrib/report/analysis_position/risk_analysis.py
+++ b/qlib/contrib/report/analysis_position/risk_analysis.py
@@ -218,7 +218,7 @@ def risk_analysis_graph(
max_drawdown -0.088263
- :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**
+ :param report_normal_df: **df.index.name** must be **date**, df.columns must contain **return**, **turnover**, **cost**, **bench**.
.. code-block:: python
@@ -232,7 +232,7 @@ def risk_analysis_graph(
2017-01-10 -0.000416 0.000440 -0.003350 0.208396
- :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**
+ :param report_long_short_df: **df.index.name** must be **date**, df.columns contain **long**, **short**, **long_short**.
.. code-block:: python
@@ -246,7 +246,7 @@ def risk_analysis_graph(
2017-01-10 0.000824 -0.001944 -0.001120
- :param show_notebook: Whether to display graphics in a notebook, default **True**
+ :param show_notebook: Whether to display graphics in a notebook, default **True**.
If True, show graph in notebook
If False, return graph figure
:return:
diff --git a/qlib/contrib/report/analysis_position/score_ic.py b/qlib/contrib/report/analysis_position/score_ic.py
index 9a2fc8560..a6a7a8b0e 100644
--- a/qlib/contrib/report/analysis_position/score_ic.py
+++ b/qlib/contrib/report/analysis_position/score_ic.py
@@ -36,7 +36,7 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
analysis_position.score_ic_graph(pred_label)
- :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**
+ :param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**.
.. code-block:: python
@@ -49,8 +49,8 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
2017-12-15 -0.102778 -0.102778
- :param show_notebook: whether to display graphics in notebook, the default is **True**
- :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list
+ :param show_notebook: whether to display graphics in notebook, the default is **True**.
+ :return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
"""
_ic_df = _get_score_ic(pred_label)
# FIXME: support HIGH-FREQ
diff --git a/qlib/contrib/strategy/strategy.py b/qlib/contrib/strategy/strategy.py
index f2e2a4554..23e8b5185 100644
--- a/qlib/contrib/strategy/strategy.py
+++ b/qlib/contrib/strategy/strategy.py
@@ -31,16 +31,16 @@ class BaseStrategy:
Parameters
-----------
score_series : pd.Seires
- stock_id , score
+ stock_id , score.
current : Position()
- current state of position
- DO NOT directly change the state of current
+ current state of position.
+ DO NOT directly change the state of current.
trade_exchange : Exchange()
- trade exchange
+ trade exchange.
pred_date : pd.Timestamp
- predict date
+ predict date.
trade_date : pd.Timestamp
- trade date
+ trade date.
"""
pass
@@ -49,11 +49,11 @@ class BaseStrategy:
Parameters
-----------
score_series : pd.Series
- stock_id , score
+ stock_id , score.
pred_date : pd.Timestamp
- oredict date
+ oredict date.
trade_date : pd.Timestamp
- trade date
+ trade date.
"""
pass
@@ -67,7 +67,7 @@ class BaseStrategy:
"""
This method only be used in 'online' module, it will generate the *args to initial the strategy.
:param
- mode : model used in 'online' module
+ mode : model used in 'online' module.
"""
return {}
@@ -82,7 +82,7 @@ class StrategyWrapper:
def __init__(self, inner_strategy):
"""__init__
- :param inner_strategy: set the inner strategy
+ :param inner_strategy: set the inner strategy.
"""
self.inner_strategy = inner_strategy
@@ -99,9 +99,9 @@ class AdjustTimer:
Responsible for timing of position adjusting
This is designed as multiple inheritance mechanism due to:
- - the is_adjust may need access to the internel state of a strategy
+ - the is_adjust may need access to the internel state of a strategy.
- - it can be reguard as a enhancement to the existing strategy
+ - it can be reguard as a enhancement to the existing strategy.
"""
# adjust position in each trade date
@@ -146,12 +146,12 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
Parameters
-----------
score : pd.Series
- pred score for this trade date, index is stock_id, contain 'score' column
+ pred score for this trade date, index is stock_id, contain 'score' column.
current : Position()
- current position
+ current position.
trade_exchange : Exchange()
trade_date : pd.Timestamp
- trade date
+ trade date.
"""
raise NotImplementedError()
@@ -160,13 +160,13 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
Parameters
-----------
score_series : pd.Seires
- stock_id , score
+ stock_id , score.
current : Position()
- current of account
+ current of account.
trade_exchange : Exchange()
- exchange
+ exchange.
trade_date : pd.Timestamp
- date
+ date.
"""
# judge if to adjust
if not self.is_adjust(trade_date):
@@ -206,26 +206,26 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
Parameters
-----------
topk : int
- The number of stocks in the portfolio
+ the number of stocks in the portfolio.
n_drop : int
- number of stocks to be replaced in each trading date
+ number of stocks to be replaced in each trading date.
method_sell : str
- dropout method_sell, random/bottom
+ dropout method_sell, random/bottom.
method_buy : str
- dropout method_buy, random/top
+ dropout method_buy, random/top.
risk_degree : float
- position percentage of total value
+ position percentage of total value.
thresh : int
- minimun holding days since last buy singal of the stock
+ minimun holding days since last buy singal of the stock.
hold_thresh : int
minimum holding days
- before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh
+ before sell stock , will check current.get_stock_count(order.stock_id) >= self.thresh.
only_tradable : bool
will the strategy only consider the tradable stock when buying and selling.
if only_tradable:
- strategy will make buy sell decision without checking the tradable state of the stock
+ strategy will make buy sell decision without checking the tradable state of the stock.
else:
- strategy will make decision with the tradable state of the stock info and avoid buy and sell them
+ strategy will make decision with the tradable state of the stock info and avoid buy and sell them.
"""
super(TopkDropoutStrategy, self).__init__()
ListAdjustTimer.__init__(self, kwargs.get("adjust_dates", None))
@@ -245,7 +245,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
def get_risk_degree(self, date):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
- Dynamically risk_degree will result in Market timing
+ Dynamically risk_degree will result in Market timing.
"""
# It will use 95% amoutn of your total value by default
return self.risk_degree
@@ -257,15 +257,15 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
Parameters
-----------
score_series : pd.Series
- stock_id , score
+ stock_id , score.
current : Position()
- current of account
+ current of account.
trade_exchange : Exchange()
- exchange
+ exchange.
pred_date : pd.Timestamp
- predict date
+ predict date.
trade_date : pd.Timestamp
- trade date
+ trade date.
"""
if not self.is_adjust(trade_date):
return []
diff --git a/qlib/data/base.py b/qlib/data/base.py
index 433b6585a..92fc57ffe 100644
--- a/qlib/data/base.py
+++ b/qlib/data/base.py
@@ -129,13 +129,13 @@ class Expression(abc.ABC):
Parameters
----------
instrument : str
- instrument code
+ instrument code.
start_index : str
- feature start index [in calendar]
+ feature start index [in calendar].
end_index : str
- feature end index [in calendar]
+ feature end index [in calendar].
freq : str
- feature frequency
+ feature frequency.
Returns
----------
diff --git a/qlib/data/cache.py b/qlib/data/cache.py
index bf8baab31..3fab2b527 100644
--- a/qlib/data/cache.py
+++ b/qlib/data/cache.py
@@ -76,8 +76,8 @@ class MemCache(object):
Parameters
----------
- mem_cache_size_limit: cache max size
- limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof)
+ mem_cache_size_limit: cache max size.
+ limit_type: length or sizeof; length(call fun: len), size(call fun: sys.getsizeof).
"""
if limit_type not in ["length", "sizeof"]:
raise ValueError(f"limit_type must be length or sizeof, your limit_type is {limit_type}")
@@ -118,9 +118,9 @@ class MemCacheExpire:
def set_cache(mem_cache, key, value):
"""set cache
- :param mem_cache: MemCache attribute('c'/'i'/'f')
- :param key: cache key
- :param value: cache value
+ :param mem_cache: MemCache attribute('c'/'i'/'f').
+ :param key: cache key.
+ :param value: cache value.
"""
mem_cache[key] = value, time.time()
@@ -128,9 +128,9 @@ class MemCacheExpire:
def get_cache(mem_cache, key):
"""get mem cache
- :param mem_cache: MemCache attribute('c'/'i'/'f')
- :param key: cache key
- :return: cache value; if cache not exist, return None
+ :param mem_cache: MemCache attribute('c'/'i'/'f').
+ :param key: cache key.
+ :return: cache value; if cache not exist, return None.
"""
value = None
expire = False
@@ -275,12 +275,12 @@ class ExpressionCache(BaseProviderCache):
Parameters
----------
cache_uri : str
- the complete uri of expression cache file (include dir path)
+ the complete uri of expression cache file (include dir path).
Returns
-------
int
- 0(successful update)/ 1(no need to update)/ 2(update failure)
+ 0(successful update)/ 1(no need to update)/ 2(update failure).
"""
raise NotImplementedError("Implement this method if you want to make expression cache up to date")
@@ -348,7 +348,7 @@ class DatasetCache(BaseProviderCache):
Parameters
----------
cache_uri : str
- the complete uri of dataset cache file (include dir path)
+ the complete uri of dataset cache file (include dir path).
Returns
-------
@@ -361,9 +361,9 @@ class DatasetCache(BaseProviderCache):
def cache_to_origin_data(data, fields):
"""cache data to origin data
- :param data: pd.DataFrame, cache data
- :param fields: feature fields
- :return: pd.DataFrame
+ :param data: pd.DataFrame, cache data.
+ :param fields: feature fields.
+ :return: pd.DataFrame.
"""
not_space_fields = remove_fields_space(fields)
data = data.loc[:, not_space_fields]
@@ -583,7 +583,7 @@ class DiskDatasetCache(DatasetCache):
:param cache_path:
:param start_time:
:param end_time:
- :param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent
+ :param fields: The fields order of the dataset cache is sorted. So rearrange the columns to make it consistent.
:return:
"""
@@ -771,12 +771,12 @@ class DiskDatasetCache(DatasetCache):
- This is a hdf file sorted by datetime
- :param cache_path: The path to store the cache
- :param instruments: The instruments to store the cache
- :param fields: The fields to store the cache
- :param freq: The freq to store the cache
+ :param cache_path: The path to store the cache.
+ :param instruments: The instruments to store the cache.
+ :param fields: The fields to store the cache.
+ :param freq: The freq to store the cache.
- :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function
+ :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
"""
# get calendar
from .data import Cal
diff --git a/qlib/data/client.py b/qlib/data/client.py
index 928faaa72..65a830f20 100644
--- a/qlib/data/client.py
+++ b/qlib/data/client.py
@@ -51,13 +51,13 @@ class Client(object):
Parameters
----------
request_type : str
- type of proposed request, 'calendar'/'instrument'/'feature'
+ type of proposed request, 'calendar'/'instrument'/'feature'.
request_content : dict
- records the information of the request
+ records the information of the request.
msg_proc_func : func
- the function to process the message when receiving response, should have arg `*args`
+ the function to process the message when receiving response, should have arg `*args`.
msg_queue: Queue
- The queue to pass the messsage after callback
+ The queue to pass the messsage after callback.
"""
head_info = {"version": qlib.__version__}
diff --git a/qlib/data/data.py b/qlib/data/data.py
index ef5e7fe8a..a4c3d63f2 100644
--- a/qlib/data/data.py
+++ b/qlib/data/data.py
@@ -41,13 +41,13 @@ class CalendarProvider(abc.ABC):
Parameters
----------
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
future : bool
- whether including future trading day
+ whether including future trading day.
Returns
----------
@@ -62,24 +62,24 @@ class CalendarProvider(abc.ABC):
Parameters
----------
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
future : bool
- whether including future trading day
+ whether including future trading day.
Returns
-------
pd.Timestamp
- the real start time
+ the real start time.
pd.Timestamp
- the real end time
+ the real end time.
int
- the index of start time
+ the index of start time.
int
- the index of end time
+ the index of end time.
"""
start_time = pd.Timestamp(start_time)
end_time = pd.Timestamp(end_time)
@@ -103,16 +103,16 @@ class CalendarProvider(abc.ABC):
Parameters
----------
freq : str
- frequency of read calendar file
+ frequency of read calendar file.
future : bool
- whether including future trading day
+ whether including future trading day.
Returns
-------
list
- list of timestamps
+ list of timestamps.
dict
- dict composed by timestamp as key and index as value for fast search
+ dict composed by timestamp as key and index as value for fast search.
"""
flag = f"{freq}_future_{future}"
if flag in H["c"]:
@@ -141,14 +141,14 @@ class InstrumentProvider(abc.ABC):
Parameters
----------
market : str
- market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500
+ market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.
filter_pipe : list
- the list of dynamic filters
+ the list of dynamic filters.
Returns
----------
dict
- dict of stockpool config
+ dict of stockpool config.
{`market`=>base market name, `filter_pipe`=>list of filters}
example :
@@ -182,13 +182,13 @@ class InstrumentProvider(abc.ABC):
Parameters
----------
instruments : dict
- stockpool config
+ stockpool config.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
as_list : bool
- return instruments as list or dict
+ return instruments as list or dict.
Returns
-------
@@ -243,15 +243,15 @@ class FeatureProvider(abc.ABC):
Parameters
----------
instrument : str
- a certain instrument
+ a certain instrument.
field : str
- a certain field of feature
+ a certain field of feature.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
Returns
-------
@@ -294,15 +294,15 @@ class ExpressionProvider(abc.ABC):
Parameters
----------
instrument : str
- a certain instrument
+ a certain instrument.
field : str
- a certain field of feature
+ a certain field of feature.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency, available: year/quarter/month/week/day
+ time frequency, available: year/quarter/month/week/day.
Returns
-------
@@ -325,20 +325,20 @@ class DatasetProvider(abc.ABC):
Parameters
----------
instruments : list or dict
- list/dict of instruments or dict of stockpool config
+ list/dict of instruments or dict of stockpool config.
fields : list
- list of feature instances
+ list of feature instances.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency
+ time frequency.
Returns
----------
pd.DataFrame
- a pandas dataframe with index
+ a pandas dataframe with index.
"""
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")
@@ -357,17 +357,17 @@ class DatasetProvider(abc.ABC):
Parameters
----------
instruments : list or dict
- list/dict of instruments or dict of stockpool config
+ list/dict of instruments or dict of stockpool config.
fields : list
- list of feature instances
+ list of feature instances.
start_time : str
- start of the time range
+ start of the time range.
end_time : str
- end of the time range
+ end of the time range.
freq : str
- time frequency
+ time frequency.
disk_cache : int
- whether to skip(0)/use(1)/replace(2) disk_cache
+ whether to skip(0)/use(1)/replace(2) disk_cache.
"""
return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache)
@@ -526,7 +526,7 @@ class LocalCalendarProvider(CalendarProvider):
Parameters
----------
freq : str
- frequency of read calendar file
+ frequency of read calendar file.
Returns
----------
diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py
index 3dbc17c23..e7d296d73 100644
--- a/qlib/data/dataset/__init__.py
+++ b/qlib/data/dataset/__init__.py
@@ -17,8 +17,8 @@ class Dataset(Serializable):
init is designed to finish following steps:
- setup data
- - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing
-
+ - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
+
- initialize the state of the dataset(info to prepare the data)
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
@@ -29,17 +29,17 @@ class Dataset(Serializable):
def setup_data(self, *args, **kwargs):
"""
- setup the data
+ Setup the data.
We split the setup_data function for following situation:
- - User have a Dataset object with learned status on disk
+ - User have a Dataset object with learned status on disk.
- - User load the Dataset object from the disk(Note the init function is skiped)
+ - User load the Dataset object from the disk(Note the init function is skiped).
- - User call `setup_data` to load new data
+ - User call `setup_data` to load new data.
- - User prepare data for model based on previous status
+ - User prepare data for model based on previous status.
"""
pass
@@ -66,9 +66,10 @@ class DatasetH(Dataset):
User should try to put the data preprocessing functions into handler.
Only following data processing functions should be placed in Dataset:
+
- The processing is related to specific model.
- - The processing is related to data split
+ - The processing is related to data split.
"""
def __init__(self, handler: Union[dict, DataHandler], segments: list):
@@ -76,15 +77,15 @@ class DatasetH(Dataset):
Parameters
----------
handler : Union[dict, DataHandler]
- handler will be passed into setup_data
+ handler will be passed into setup_data.
segments : list
- handler will be passed into setup_data
+ handler will be passed into setup_data.
"""
super().__init__(handler, segments)
def setup_data(self, handler: Union[dict, DataHandler], segments: list):
"""
- setup the underlying data
+ Setup the underlying data.
Parameters
----------
@@ -94,12 +95,13 @@ class DatasetH(Dataset):
- insntance of `DataHandler`
- config of `DataHandler`. Please refer to `DataHandler`
+
segments : list
Describe the options to segment the data.
Here are some examples:
.. code-block::
-
+
1) 'segments': {
'train': ("2008-01-01", "2014-12-31"),
'valid': ("2017-01-01", "2020-08-01",),
@@ -121,7 +123,7 @@ class DatasetH(Dataset):
**kwargs,
) -> Union[List[pd.DataFrame], pd.DataFrame]:
"""
- prepare the data for learning and inference
+ Prepare the data for learning and inference.
Parameters
----------
@@ -132,11 +134,12 @@ class DatasetH(Dataset):
- 'train'
- ['train', 'valid']
+
col_set : str
- The col_set will be passed to self._handler when fetching data
- data_key: str
+ The col_set will be passed to self._handler when fetching data.
+ data_key : str
The data to fetch: DK_*
- Default is DK_I, which indicate fetching data for **inference**
+ Default is DK_I, which indicate fetching data for **inference**.
Returns
-------
diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py
index 4d3d88c38..905fcd623 100644
--- a/qlib/data/dataset/handler.py
+++ b/qlib/data/dataset/handler.py
@@ -29,7 +29,7 @@ class DataHandler(Serializable):
"""
The steps to using a handler
1. initialized data handler (call by `init`).
- 2. use the data
+ 2. use the data.
The data handler try to maintain a handler with 2 level.
@@ -65,17 +65,17 @@ class DataHandler(Serializable):
Parameters
----------
instruments :
- The stock list to retrive
+ The stock list to retrive.
start_time :
- start_time of the original data
+ start_time of the original data.
end_time :
- end_time of the original data
+ end_time of the original data.
data_loader : Tuple[dict, str, DataLoader]
- data loader to load the data
+ data loader to load the data.
init_data :
- intialize the original data in the constructor
+ intialize the original data in the constructor.
fetch_orig : bool
- Return the original data instead of copy if possible
+ Return the original data instead of copy if possible.
"""
# Set logger
self.logger = get_module_logger("DataHandler")
@@ -219,9 +219,9 @@ class DataHandler(Serializable):
get a iterator of sliced data with given periods
Args:
- periods (int): number of periods
- min_periods (int): minimum periods for sliced dataframe
- kwargs (dict): will be passed to `self.fetch`
+ periods (int): number of periods.
+ min_periods (int): minimum periods for sliced dataframe.
+ kwargs (dict): will be passed to `self.fetch`.
"""
trading_dates = self._data.index.unique(level="datetime")
if min_periods is None:
@@ -243,10 +243,10 @@ class DataHandlerLP(DataHandler):
# process type
PTYPE_I = "independent"
- # - self._infer will processed by infer_processors
+ # - self._infer will be processed by infer_processors
# - self._learn will be processed by learn_processors
PTYPE_A = "append"
- # - self._infer will processed by infer_processors
+ # - self._infer will be processed by infer_processors
# - self._learn will be processed by infer_processors + learn_processors
# - (e.g. self._infer processed by learn_processors )
@@ -265,30 +265,40 @@ class DataHandlerLP(DataHandler):
Parameters
----------
infer_processors : list
- list of of processors to generate data for inference
- example of :
- 1) classname & kwargs:
- {
- "class": "MinMaxNorm",
- "kwargs": {
- "fit_start_time": "20080101",
- "fit_end_time": "20121231"
+ - list of of processors to generate data for inference
+
+ - example of :
+
+ .. code-block::
+
+ 1) classname & kwargs:
+ {
+ "class": "MinMaxNorm",
+ "kwargs": {
+ "fit_start_time": "20080101",
+ "fit_end_time": "20121231"
+ }
}
- }
- 2) Only classname:
- "DropnaFeature"
- 3) object instance of Processor
+ 2) Only classname:
+ "DropnaFeature"
+ 3) object instance of Processor
learn_processors : list
similar to infer_processors, but for generating data for learning models
process_type: str
PTYPE_I = 'independent'
+
- self._infer will processed by infer_processors
+
- self._learn will be processed by learn_processors
+
PTYPE_A = 'append'
+
- self._infer will processed by infer_processors
+
- self._learn will be processed by infer_processors + learn_processors
+
- (e.g. self._infer processed by learn_processors )
"""
@@ -377,7 +387,7 @@ class DataHandlerLP(DataHandler):
Parameters
----------
init_type : str
- The type `IT_*` listed above
+ The type `IT_*` listed above.
enable_cache : bool
default value is false:
@@ -419,13 +429,13 @@ class DataHandlerLP(DataHandler):
Parameters
----------
selector : Union[pd.Timestamp, slice, str]
- describe how to select data by index
+ describe how to select data by index.
level : Union[str, int]
- which index level to select the data
+ which index level to select the data.
col_set : str
- select a set of meaningful columns.(e.g. features, columns)
- data_key: str
- The data to fetch: DK_*
+ select a set of meaningful columns.(e.g. features, columns).
+ data_key : str
+ the data to fetch: DK_*.
Returns
-------
@@ -443,9 +453,9 @@ class DataHandlerLP(DataHandler):
Parameters
----------
col_set : str
- select a set of meaningful columns.(e.g. features, columns)
- data_key: str
- The data to fetch: DK_*
+ select a set of meaningful columns.(e.g. features, columns).
+ data_key : str
+ the data to fetch: DK_*.
Returns
-------
diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py
index 404313e80..a51ea119a 100644
--- a/qlib/data/dataset/loader.py
+++ b/qlib/data/dataset/loader.py
@@ -21,27 +21,11 @@ class DataLoader(abc.ABC):
@abc.abstractmethod
def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
"""
- load the data as pd.DataFrame
+ load the data as pd.DataFrame.
- Parameters
- ----------
- self : [TODO:type]
- [TODO:description]
- instruments : [TODO:type]
- [TODO:description]
- start_time : [TODO:type]
- [TODO:description]
- end_time : [TODO:type]
- [TODO:description]
+ Example of the data (The multi-index of the columns is optional.):
- Returns
- -------
- pd.DataFrame:
- data load from the under layer source
-
- Example of the data (The multi-index of the columns is optional.):
-
- .. code-block::
+ .. code-block:: python
feature label
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
@@ -49,6 +33,21 @@ class DataLoader(abc.ABC):
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
+
+
+ Parameters
+ ----------
+ instruments : str or dict
+ it can either be the market name or the config file of instruments generated by InstrumentProvider.
+ start_time : str
+ start of the time range.
+ end_time : str
+ end of the time range.
+
+ Returns
+ -------
+ pd.DataFrame:
+ data load from the under layer source
"""
pass
@@ -67,7 +66,7 @@ class DLWParser(DataLoader):
config : Tuple[list, tuple, dict]
Config will be used to describe the fields and column names
- .. code-block:: YAML
+ .. code-block::
:= {
"group_name1":
@@ -102,16 +101,16 @@ class DLWParser(DataLoader):
Parameters
----------
instruments :
- the instruments
+ the instruments.
exprs : list
- The expressions to describe the content of the data
+ the expressions to describe the content of the data.
names : list
- The name of the data
+ the name of the data.
Returns
-------
pd.DataFrame:
- the queried dataframe
+ the queried dataframe.
"""
pass
diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py
index 4a2d36e2f..fc85ccde9 100755
--- a/qlib/data/dataset/processor.py
+++ b/qlib/data/dataset/processor.py
@@ -21,7 +21,7 @@ def get_group_columns(df: pd.DataFrame, group: str):
Parameters
----------
df : pd.DataFrame
- with multi of columns
+ with multi of columns.
group : str
the name of the feature group, i.e. the first level value of the group index.
"""
@@ -56,7 +56,7 @@ class Processor(Serializable):
Parameters
----------
df : pd.DataFrame
- The raw_df of handler or result from previous processor
+ The raw_df of handler or result from previous processor.
"""
pass
@@ -68,7 +68,7 @@ class Processor(Serializable):
Returns
-------
bool:
- if it is usable for infenrece
+ if it is usable for infenrece.
"""
return True
@@ -176,7 +176,9 @@ class MinMaxNorm(Processor):
return df
-class ZscoreNorm(Processor):
+class ZScoreNorm(Processor):
+ """ZScore Normalization"""
+
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
self.fit_start_time = fit_start_time
self.fit_end_time = fit_end_time
@@ -203,6 +205,42 @@ class ZscoreNorm(Processor):
return df
+class RobustZScoreNorm(Processor):
+ """Robust ZScore Normalization
+
+ Use robust statistics for Z-Score normalization:
+ mean(x) = median(x)
+ std(x) = MAD(x) * 1.4826
+
+ Reference:
+ https://en.wikipedia.org/wiki/Median_absolute_deviation.
+ """
+
+ def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True):
+ self.fit_start_time = fit_start_time
+ self.fit_end_time = fit_end_time
+ self.fields_group = fields_group
+ self.clip_outlier = clip_outlier
+
+ def fit(self, df):
+ df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
+ self.cols = get_group_columns(df, self.fields_group)
+ X = df[self.cols].values
+ self.mean_train = np.nanmedian(X, axis=0)
+ self.std_train = np.nanmedian(np.abs(X - self.mean_train), axis=0)
+ self.std_train += EPS
+ self.std_train *= 1.4826
+
+ def __call__(self, df):
+ X = df[self.cols]
+ X -= self.mean_train
+ X /= self.std_train
+ df[self.cols] = X
+ if self.clip_outlier:
+ df.clip(-3, 3, inplace=True)
+ return df
+
+
class CSZScoreNorm(Processor):
"""Cross Sectional ZScore Normalization"""
diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py
index 3fb3768a0..feda19044 100644
--- a/qlib/data/dataset/utils.py
+++ b/qlib/data/dataset/utils.py
@@ -51,6 +51,9 @@ def fetch_df_by_index(
-------
Data of the given index.
"""
+ # level = None -> use selector directly
+ if level == None:
+ return df.loc(axis=0)[selector]
# Try to get the right index
idx_slc = (selector, slice(None, None))
if get_level_index(df, level) == 1:
diff --git a/qlib/data/filter.py b/qlib/data/filter.py
index 47b093b67..70f9d3278 100644
--- a/qlib/data/filter.py
+++ b/qlib/data/filter.py
@@ -32,7 +32,7 @@ class BaseDFilter(abc.ABC):
Parameters
----------
config : dict
- dict of config parameters
+ dict of config parameters.
"""
raise NotImplementedError("Subclass of BaseDFilter must reimplement `from_config` method")
@@ -43,7 +43,7 @@ class BaseDFilter(abc.ABC):
Returns
----------
dict
- return the dict of config parameters
+ return the dict of config parameters.
"""
raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method")
@@ -69,9 +69,9 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
fstart_time: str
- the time for the filter rule to start filter the instruments
+ the time for the filter rule to start filter the instruments.
fend_time: str
- the time for the filter rule to stop filter the instruments
+ the time for the filter rule to stop filter the instruments.
"""
super(SeriesDFilter, self).__init__()
self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None
@@ -83,12 +83,12 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
instruments: dict
- the dict of instruments in the form {instrument_name => list of timestamp tuple}
+ the dict of instruments in the form {instrument_name => list of timestamp tuple}.
Returns
----------
pd.Timestamp, pd.Timestamp
- the lower time bound and upper time bound of all the instruments
+ the lower time bound and upper time bound of all the instruments.
"""
trange = Cal.calendar(freq=self.filter_freq)
ubound, lbound = trange[0], trange[-1]
@@ -105,14 +105,14 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
time_range : D.calendar
- the time range of the instruments
+ the time range of the instruments.
target_timestamp : list
- the list of tuple (timestamp, timestamp)
+ the list of tuple (timestamp, timestamp).
Returns
----------
pd.Series
- the series of bool value for an instrument
+ the series of bool value for an instrument.
"""
# Construct a whole dict of {date => bool}
timestamp_series = {timestamp: False for timestamp in time_range}
@@ -124,19 +124,19 @@ class SeriesDFilter(BaseDFilter):
return timestamp_series
def _filterSeries(self, timestamp_series, filter_series):
- """Filter the timestamp series with filter series by using element-wise AND operation of the two series
+ """Filter the timestamp series with filter series by using element-wise AND operation of the two series.
Parameters
----------
timestamp_series : pd.Series
- the series of bool value indicating existing time
+ the series of bool value indicating existing time.
filter_series : pd.Series
- the series of bool value indicating filter feature
+ the series of bool value indicating filter feature.
Returns
----------
pd.Series
- the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp
+ the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp.
"""
fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1]
filter_series = filter_series.astype("bool") # Make sure the filter_series is boolean
@@ -144,17 +144,17 @@ class SeriesDFilter(BaseDFilter):
return timestamp_series
def _toTimestamp(self, timestamp_series):
- """Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE
+ """Convert the timestamp series to a list of tuple (timestamp, timestamp) indicating a continuous range of TRUE.
Parameters
----------
timestamp_series: pd.Series
- the series of bool value after being filtered
+ the series of bool value after being filtered.
Returns
----------
list
- the list of tuple (timestamp, timestamp)
+ the list of tuple (timestamp, timestamp).
"""
# sort the timestamp_series according to the timestamps
timestamp_series.sort_index()
@@ -194,18 +194,18 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
instruments : dict
- the dict of instruments to be filtered
+ the dict of instruments to be filtered.
fstart : pd.Timestamp
- start time of filter
+ start time of filter.
fend : pd.Timestamp
- end time of filter
+ end time of filter.
- .. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time
+ .. note:: fstart/fend indicates the intersection of instruments start/end time and filter start/end time.
Returns
----------
pd.Dataframe
- a series of {pd.Timestamp => bool}
+ a series of {pd.Timestamp => bool}.
"""
raise NotImplementedError("Subclass of SeriesDFilter must reimplement `getFilterSeries` method")
@@ -215,16 +215,16 @@ class SeriesDFilter(BaseDFilter):
Parameters
----------
instruments: dict
- input instruments to be filtered
+ input instruments to be filtered.
start_time: str
- start of the time range
+ start of the time range.
end_time: str
- end of the time range
+ end of the time range.
Returns
----------
dict
- filtered instruments, same structure as input instruments
+ filtered instruments, same structure as input instruments.
"""
lbound, ubound = self._getTimeBound(instruments)
start_time = pd.Timestamp(start_time or lbound)
@@ -272,7 +272,7 @@ class NameDFilter(SeriesDFilter):
params:
------
name_rule_re: str
- regular expression for the name rule
+ regular expression for the name rule.
"""
super(NameDFilter, self).__init__(fstart_time, fend_time)
self.name_rule_re = name_rule_re
@@ -325,13 +325,13 @@ class ExpressionDFilter(SeriesDFilter):
params:
------
fstart_time: str
- filter the feature starting from this time
+ filter the feature starting from this time.
fend_time: str
- filter the feature ending by this time
+ filter the feature ending by this time.
rule_expression: str
- an input expression for the rule
+ an input expression for the rule.
keep: bool
- whether to keep the instruments of which features don't exist in the filter time span
+ whether to keep the instruments of which features don't exist in the filter time span.
"""
super(ExpressionDFilter, self).__init__(fstart_time, fend_time)
self.rule_expression = rule_expression
diff --git a/qlib/data/ops.py b/qlib/data/ops.py
index 179193167..e17c0e4e6 100644
--- a/qlib/data/ops.py
+++ b/qlib/data/ops.py
@@ -18,9 +18,8 @@ try:
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
from ._libs.expanding import expanding_slope, expanding_rsquare, expanding_resi
except ImportError as err:
- print(err)
- print("Do not import qlib package in the repository directory")
- sys.exit(-1)
+ print("Do not import qlib package in the repository directory!")
+ raise
__all__ = (
"Ref",
@@ -865,6 +864,8 @@ class Skew(Rolling):
"""
def __init__(self, feature, N):
+ if N != 0 and N < 3:
+ raise ValueError("The rolling window size of Skewness operation should >= 3")
super(Skew, self).__init__(feature, N, "skew")
@@ -885,6 +886,8 @@ class Kurt(Rolling):
"""
def __init__(self, feature, N):
+ if N != 0 and N < 4:
+ raise ValueError("The rolling window size of Kurtosis operation should >= 5")
super(Kurt, self).__init__(feature, N, "kurt")
@@ -1268,7 +1271,7 @@ class WMA(Rolling):
def weighted_mean(x):
w = np.arange(len(x))
- w /= w.sum()
+ w = w / w.sum()
return np.nanmean(w * x)
if self.N == 0:
diff --git a/qlib/model/base.py b/qlib/model/base.py
index d6ee50e33..fd220cd7e 100644
--- a/qlib/model/base.py
+++ b/qlib/model/base.py
@@ -33,7 +33,7 @@ class Model(BaseModel):
Parameters
----------
dataset : Dataset
- dataset will generate the processed data from model training
+ dataset will generate the processed data from model training.
"""
raise NotImplementedError()
@@ -44,7 +44,7 @@ class Model(BaseModel):
Parameters
----------
dataset : Dataset
- dataset will generate the processed dataset from model training
+ dataset will generate the processed dataset from model training.
"""
raise NotImplementedError()
@@ -59,6 +59,6 @@ class ModelFT(Model):
Parameters
----------
dataset : Dataset
- dataset will generate the processed dataset from model training
+ dataset will generate the processed dataset from model training.
"""
raise NotImplementedError()
diff --git a/qlib/model/riskmodel.py b/qlib/model/riskmodel.py
index b5275213b..07a1e0c9f 100644
--- a/qlib/model/riskmodel.py
+++ b/qlib/model/riskmodel.py
@@ -23,9 +23,9 @@ class RiskModel(BaseModel):
def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True):
"""
Args:
- nan_option (str): nan handling option (`ignore`/`mask`/`fill`)
- assume_centered (bool): whether the data is assumed to be centered
- scale_return (bool): whether scale returns as percentage
+ nan_option (str): nan handling option (`ignore`/`mask`/`fill`).
+ assume_centered (bool): whether the data is assumed to be centered.
+ scale_return (bool): whether scale returns as percentage.
"""
# nan
assert nan_option in [
@@ -45,11 +45,11 @@ class RiskModel(BaseModel):
Args:
X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance,
with variables as columns and observations as rows.
- return_corr (bool): whether return the correlation matrix
- is_price (bool): whether `X` contains price (if not assume stock returns)
+ return_corr (bool): whether return the correlation matrix.
+ is_price (bool): whether `X` contains price (if not assume stock returns).
Returns:
- pd.DataFrame or np.ndarray: estimated covariance (or correlation)
+ pd.DataFrame or np.ndarray: estimated covariance (or correlation).
"""
# transform input into 2D array
if not isinstance(X, (pd.Series, pd.DataFrame)):
@@ -101,10 +101,10 @@ class RiskModel(BaseModel):
By default, this method implements the empirical covariance estimation.
Args:
- X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows)
+ X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).
Returns:
- np.ndarray: covariance matrix
+ np.ndarray: covariance matrix.
"""
xTx = np.asarray(X.T.dot(X))
N = len(X)
@@ -117,7 +117,7 @@ class RiskModel(BaseModel):
"""handle nan and centerize data
Note:
- if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`
+ if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`.
"""
# handle nan
if self.nan_option == self.FILL_NAN:
@@ -139,15 +139,15 @@ class ShrinkCovEstimator(RiskModel):
where `alpha` is the shrink parameter and `F` is the shrinking target.
The following shrinking parameters (`alpha`) are supported:
- - `lw` [1][2][3]: use Ledoit-Wolf shrinking parameter
- - `oas` [4]: use Oracle Approximating Shrinkage shrinking parameter
- - float: directly specify the shrink parameter, should be between [0, 1]
+ - `lw` [1][2][3]: use Ledoit-Wolf shrinking parameter.
+ - `oas` [4]: use Oracle Approximating Shrinkage shrinking parameter.
+ - float: directly specify the shrink parameter, should be between [0, 1].
The following shrinking targets (`F`) are supported:
- - `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation
- - `const_corr` [2][6]: assume stocks have different variance but equal correlation
- - `single_factor` [3][7]: assume single factor model as the shrinking target
- - np.ndarray: provide the shrinking targets directly
+ - `const_var` [1][4][5]: assume stocks have the same constant variance and zero correlation.
+ - `const_corr` [2][6]: assume stocks have different variance but equal correlation.
+ - `single_factor` [3][7]: assume single factor model as the shrinking target.
+ - np.ndarray: provide the shrinking targets directly.
Note:
- The optimal shrinking parameter depends on the selection of the shrinking target.
@@ -402,13 +402,13 @@ class POETCovEstimator(RiskModel):
def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs):
"""
Args:
- num_factors (int): number of factors (if set to zero, no factor model will be used)
- thresh (float): the positive constant for thresholding
+ num_factors (int): number of factors (if set to zero, no factor model will be used).
+ thresh (float): the positive constant for thresholding.
thresh_method (str): thresholding method, which can be
- - 'soft': soft thresholding
- - 'hard': hard thresholding
- - 'scad': scad thresholding
- kwargs: see `RiskModel` for more information
+ - 'soft': soft thresholding.
+ - 'hard': hard thresholding.
+ - 'scad': scad thresholding.
+ kwargs: see `RiskModel` for more information.
"""
super().__init__(**kwargs)
diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py
new file mode 100644
index 000000000..e4fc8eef9
--- /dev/null
+++ b/qlib/model/trainer.py
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from qlib.utils import init_instance_by_config, flatten_dict
+from qlib.workflow import R
+from qlib.workflow.record_temp import SignalRecord
+
+
+def task_train(config: dict, experiment_name):
+ """
+ task based training
+
+ Parameters
+ ----------
+ config : dict
+ A dict describing the training process
+ """
+
+ # model initiaiton
+ model = init_instance_by_config(config.get("task")["model"])
+ dataset = init_instance_by_config(config.get("task")["dataset"])
+
+ # start exp
+ with R.start(experiment_name=experiment_name):
+ # train model
+ R.log_params(**flatten_dict(config.get("task")))
+ model.fit(dataset)
+ recorder = R.get_recorder()
+
+ # generate records: prediction, backtest, and analysis
+ for record in config.get("task")["record"]:
+ if record["class"] == SignalRecord.__name__:
+ srconf = {"model": model, "dataset": dataset, "recorder": recorder}
+ record["kwargs"].update(srconf)
+ sr = init_instance_by_config(record)
+ sr.generate()
+ else:
+ rconf = {"recorder": recorder}
+ record["kwargs"].update(rconf)
+ ar = init_instance_by_config(record)
+ ar.generate()
diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py
index 8944ecbe6..c0745f6d4 100644
--- a/qlib/workflow/__init__.py
+++ b/qlib/workflow/__init__.py
@@ -10,22 +10,6 @@ from ..utils import Wrapper
class QlibRecorder:
"""
A global system that helps to manage the experiments.
-
- The components of the system:
- 1) ExperimentManager: a class managing experiments.
- 2) Experiment: a class of experiment, and each instance of it is responsible for a single experiment.
- 3) Recorder: a class of recorder, and each instance of it is responsible for a single run.
-
- The general structure of the system:
- ExperimentManager
- - Experiment 1
- - Recorder 1
- - Recorder 2
- - ...
- - Experiment 2
- - ...
- - ...
-
"""
def __init__(self, exp_manager):
@@ -34,16 +18,14 @@ class QlibRecorder:
@contextmanager
def start(self, experiment_name=None, recorder_name=None):
"""
- Method to start an experiment. This method can only be called within a Python's `with` statement.
+ Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code:
- Use case:
- ---------
- ```
- with R.start('test', 'recorder_1'):
- model.fit(dataset)
- R.log...
- ... # further operations
- ```
+ .. code-block:: Python
+
+ with R.start('test', 'recorder_1'):
+ model.fit(dataset)
+ R.log...
+ ... # further operations
Parameters
----------
@@ -63,15 +45,14 @@ class QlibRecorder:
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
"""
Lower level method for starting an experiment. When use this method, one should end the experiment manually
- and the status of the recorder may not be handled properly.
+ and the status of the recorder may not be handled properly. Here is the example code:
+
+ .. code-block:: Python
+
+ R.start_exp(experiment_name='test', recorder_name='recorder_1')
+ ... # further operations
+ R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
- Use case:
- ---------
- ```
- R.start_exp(experiment_name='test', recorder_name='recorder_1')
- ... # further operations
- R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
- ```
Parameters
----------
@@ -92,15 +73,13 @@ class QlibRecorder:
def end_exp(self, recorder_status=Recorder.STATUS_FI):
"""
Method for ending an experiment manually. It will end the current active experiment, as well as its
- active recorder with the specified `status` type.
+ active recorder with the specified `status` type. Here is the example code of the method:
- Use case:
- ---------
- ```
- R.start_exp(experiment_name='test')
- ... # further operations
- R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
- ```
+ .. code-block:: Python
+
+ R.start_exp(experiment_name='test')
+ ... # further operations
+ R.end_exp('FINISHED') or R.end_exp(Recorder.STATUS_S)
Parameters
----------
@@ -111,14 +90,12 @@ class QlibRecorder:
def search_records(self, experiment_ids, **kwargs):
"""
- Get a pandas DataFrame of records that fit the search criteria.
+ Get a pandas DataFrame of records that fit the search criteria. Here is the example code of the method:
- Use case:
- ---------
- ```
- R.log_metrics(m=2.50, step=0)
- records = R.search_runs([experiment_id], order_by=["metrics.m DESC"])
- ```
+ .. code-block:: Python
+
+ R.log_metrics(m=2.50, step=0)
+ records = R.search_runs([experiment_id], order_by=["metrics.m DESC"])
Parameters
----------
@@ -146,11 +123,9 @@ class QlibRecorder:
"""
Method for listing all the existing experiments (except for those being deleted.)
- Use case:
- ---------
- ```
- exps = R.list_experiments()
- ```
+ .. code-block:: Python
+
+ exps = R.list_experiments()
Returns
-------
@@ -166,11 +141,11 @@ class QlibRecorder:
list all the recorders of the default experiment. If the default experiment doesn't exist, the method will first
create the default experiment, and then create a new recorder under it.
- Use case:
- ---------
- ```
- recorders = R.list_recorders(experiment_name='test')
- ```
+ Here is the example code:
+
+ .. code-block:: Python
+
+ recorders = R.list_recorders(experiment_name='test')
Parameters
----------
@@ -191,46 +166,55 @@ class QlibRecorder:
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
only retrieve a specific experiment or raise an Error.
- If `create` is True:
- If R's running:
- 1) no id or name specified, return the active experiment.
- 2) if id or name is specified, return the specified experiment. If no such exp found,
- create a new experiment with given id or name, and the experiment is set to be running.
- If R's not running:
- 1) no id or name specified, create a default experiment, and the experiment is set to be running.
- 2) if id or name is specified, return the specified experiment. If no such exp found,
- create a new experiment with given name or the default experiment, and the experiment is set to be running.
- Else If `create` is False:
- If R's running:
- 1) no id or name specified, return the active experiment.
- 2) if id or name is specified, return the specified experiment. If no such exp found,
- raise Error.
- If R's not running:
- 1) no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
- 2) if id or name is specified, return the specified experiment. If no such exp found,
- raise Error.
+ - If '`create`' is True:
- Use case:
- ---------
- ```
- # Case 1
- with R.start('test'):
- exp = R.get_exp()
- recorders = exp.list_recorders()
+ - If ``R``'s running:
- # Case 2
- with R.start('test'):
- exp = R.get_exp('test1')
+ - no id or name specified, return the active experiment.
- # Case 3
- exp = R.get_exp() -> a default experiment.
+ - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be running.
- # Case 4
- exp = R.get_exp(experiment_name='test')
+ - If ``R``'s not running:
- # Case 5
- exp = R.get_exp(create=False) -> the default experiment if exists.
- ```
+ - no id or name specified, create a default experiment, and the experiment is set to be running.
+
+ - if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be running.
+
+ - Else If '`create`' is False:
+
+ - If ``R``'s running:
+
+ - no id or name specified, return the active experiment.
+
+ - if id or name is specified, return the specified experiment. If no such exp found, raise Error.
+
+ - If ``R``'s not running:
+
+ - no id or name specified. If the default experiment exists, return it, otherwise, raise Error.
+
+ - if id or name is specified, return the specified experiment. If no such exp found, raise Error.
+
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ exp = R.get_exp()
+ recorders = exp.list_recorders()
+
+ # Case 2
+ with R.start('test'):
+ exp = R.get_exp('test1')
+
+ # Case 3
+ exp = R.get_exp() -> a default experiment.
+
+ # Case 4
+ exp = R.get_exp(experiment_name='test')
+
+ # Case 5
+ exp = R.get_exp(create=False) -> the default experiment if exists.
Parameters
----------
@@ -253,11 +237,11 @@ class QlibRecorder:
Method for deleting the experiment with given id or name. At least one of id or name must be given,
otherwise, error will occur.
- Use case:
- ---------
- ```
- R.delete_exp(experiment_name='test')
- ```
+ Here is the example code:
+
+ .. code-block:: Python
+
+ R.delete_exp(experiment_name='test')
Parameters
----------
@@ -272,11 +256,11 @@ class QlibRecorder:
"""
Method for retrieving the uri of current experiment manager.
- Use case:
- ---------
- ```
- uri = R.get_uri()
- ```
+ Here is the example code:
+
+ .. code-block:: Python
+
+ uri = R.get_uri()
Returns
-------
@@ -288,35 +272,41 @@ class QlibRecorder:
"""
Method for retrieving a recorder.
- If R's running: 1) no id or name specified, return the active recorder. 2) if id or name is
- specified, return the specified recorder.
- If R's not running: 1) no id or name specified, raise Error. 2) if id or name is specified,
- and the corresponding experiment_name must be given, return the specified recorder. Otherwise,
- raise Error.
+ - If ``R``'s running:
+
+ - no id or name specified, return the active recorder.
+
+ - if id or name is specified, return the specified recorder.
+
+ - If ``R``'s not running:
+
+ - no id or name specified, raise Error.
+
+ - if id or name is specified, and the corresponding experiment_name must be given, return the specified recorder. Otherwise, raise Error.
The recorder can be used for further process such as `save_object`, `load_object`, `log_params`,
`log_metrics`, etc.
- Use case:
- ---------
- ```
- # Case 1
- with R.start('test'):
- recorder = R.get_recorder()
+ Here are some use cases:
- # Case 2
- with R.start('test'):
- recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
+ .. code-block:: Python
- # Case 3
- recorder = R.get_recorder() -> Error
+ # Case 1
+ with R.start('test'):
+ recorder = R.get_recorder()
- # Case 4
- recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d') -> Error
+ # Case 2
+ with R.start('test'):
+ recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
- # Case 5
- recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')
- ```
+ # Case 3
+ recorder = R.get_recorder() -> Error
+
+ # Case 4
+ recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d') -> Error
+
+ # Case 5
+ recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')
Parameters
----------
@@ -340,11 +330,11 @@ class QlibRecorder:
Method for deleting the recorders with given id or name. At least one of id or name must be given,
otherwise, error will occur.
- Use case:
- ---------
- ```
- R.delete_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
- ```
+ Here is the example code:
+
+ .. code-block:: Python
+
+ R.delete_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d')
Parameters
----------
@@ -361,26 +351,25 @@ class QlibRecorder:
from a local file/directory, or directly saving objects. User can use valid python's keywords arguments
to specify the object to be saved as well as its name (name: value).
- If R's running: it will save the objects through the running recorder.
- If R's not running: the system will create a default experiment, and a new recorder and
- save objects under it.
+ - If R's running: it will save the objects through the running recorder.
+ - If R's not running: the system will create a default experiment, and a new recorder and save objects under it.
- If one wants to save objects with a specific recorder. It is recommended to first
- get the specific recorder through `get_recorder` API and use the recorder the save objects.
- The supported arguments are the same as this method.
+ .. note::
- Use case:
- ---------
- ```
- # Case 1
- with R.start('test'):
- pred = model.predict(dataset)
- R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
+ If one wants to save objects with a specific recorder. It is recommended to first get the specific recorder through `get_recorder` API and use the recorder the save objects. The supported arguments are the same as this method.
- # Case 2
- with R.start('test'):
- R.save_objects(local_path='results/pred.pkl')
- ```
+ Here are some use cases:
+
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ pred = model.predict(dataset)
+ R.save_objects(**{"pred.pkl": pred}, artifact_path='prediction')
+
+ # Case 2
+ with R.start('test'):
+ R.save_objects(local_path='results/pred.pkl')
Parameters
----------
@@ -393,25 +382,22 @@ class QlibRecorder:
def log_params(self, **kwargs):
"""
- Method for logging parameters during an experiment.
+ Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
- If R's running: it will log parameters through the running recorder.
- If R's not running: the system will create a default experiment as well as a new recorder, and
- log parameters under it.
+ - If R's running: it will log parameters through the running recorder.
+ - If R's not running: the system will create a default experiment as well as a new recorder, and log parameters under it.
- One can also log to a specific recorder after getting it with `get_recorder` API.
+ Here are some use cases:
- Use case:
- ---------
- ```
- # Case 1
- with R.start('test'):
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ R.log_params(learning_rate=0.01)
+
+ # Case 2
R.log_params(learning_rate=0.01)
- # Case 2
- R.log_params(learning_rate=0.01)
- ```
-
Parameters
----------
keyword argument:
@@ -421,25 +407,22 @@ class QlibRecorder:
def log_metrics(self, step=None, **kwargs):
"""
- Method for logging metrics during an experiment.
+ Method for logging metrics during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
- If R's running: it will log metrics through the running recorder.
- If R's not running: the system will create a default experiment as well as a new recorder, and
- log metrics under it.
+ - If R's running: it will log metrics through the running recorder.
+ - If R's not running: the system will create a default experiment as well as a new recorder, and log metrics under it.
- One can also log to a specific recorder after getting it with `get_recorder` API.
+ Here are some use cases:
- Use case:
- ---------
- ```
- # Case 1
- with R.start('test'):
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ R.log_metrics(train_loss=0.33, step=1)
+
+ # Case 2
R.log_metrics(train_loss=0.33, step=1)
- # Case 2
- R.log_metrics(train_loss=0.33, step=1)
- ```
-
Parameters
----------
keyword argument:
@@ -449,25 +432,22 @@ class QlibRecorder:
def set_tags(self, **kwargs):
"""
- Method for setting tags for a recorder.
+ Method for setting tags for a recorder. In addition to using ``R``, one can also set the tag to a specific recorder after getting it with `get_recorder` API.
- If R's running: it will set tags through the running recorder.
- If R's not running: the system will create a default experiment as well as a new recorder, and
- set the tags under it.
+ - If R's running: it will set tags through the running recorder.
+ - If R's not running: the system will create a default experiment as well as a new recorder, and set the tags under it.
- One can also set the tag to a specific recorder after getting it with `get_recorder` API.
+ Here are some use cases:
- Use case:
- ---------
- ```
- # Case 1
- with R.start('test'):
+ .. code-block:: Python
+
+ # Case 1
+ with R.start('test'):
+ R.set_tags(release_version="2.2.0")
+
+ # Case 2
R.set_tags(release_version="2.2.0")
- # Case 2
- R.set_tags(release_version="2.2.0")
- ```
-
Parameters
----------
keyword argument:
diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py
index a946af9a7..08c13de2a 100644
--- a/qlib/workflow/cli.py
+++ b/qlib/workflow/cli.py
@@ -8,9 +8,36 @@ import qlib
import fire
import pandas as pd
import ruamel.yaml as yaml
-from qlib.utils import init_instance_by_config, flatten_dict
-from qlib.workflow import R
-from qlib.workflow.record_temp import SignalRecord
+from qlib.model.trainer import task_train
+
+
+def get_path_list(path):
+ if isinstance(path, str):
+ return [path]
+ else:
+ return [p for p in path]
+
+
+def sys_config(config, config_path):
+ """
+ Configure the `sys` section
+
+ Parameters
+ ----------
+ config : dict
+ configuration of the workflow.
+ config_path : str
+ path of the configuration
+ """
+ sys_config = config.get("sys", {})
+
+ # abspath
+ for p in get_path_list(sys_config.get("path", [])):
+ sys.path.append(p)
+
+ # relative path to config path
+ for p in get_path_list(sys_config.get("rel_path", [])):
+ sys.path.append(str(Path(config_path).parent.resolve().absolute() / p))
# worflow handler function
@@ -18,33 +45,14 @@ def workflow(config_path, experiment_name="workflow"):
with open(config_path) as fp:
config = yaml.load(fp, Loader=yaml.Loader)
+ # config the `sys` section
+ sys_config(config, config_path)
+
provider_uri = config.get("provider_uri")
region = config.get("region")
qlib.init(provider_uri=provider_uri, region=region)
- # model initiaiton
- model = init_instance_by_config(config.get("task")["model"])
- dataset = init_instance_by_config(config.get("task")["dataset"])
-
- # start exp
- with R.start(experiment_name=experiment_name):
- # train model
- R.log_params(**flatten_dict(config.get("task")))
- model.fit(dataset)
- recorder = R.get_recorder()
-
- # generate records: prediction, backtest, and analysis
- for record in config.get("task")["record"]:
- if record["class"] == SignalRecord.__name__:
- srconf = {"model": model, "dataset": dataset, "recorder": recorder}
- record["kwargs"].update(srconf)
- sr = init_instance_by_config(record)
- sr.generate()
- else:
- rconf = {"recorder": recorder}
- record["kwargs"].update(rconf)
- ar = init_instance_by_config(record)
- ar.generate()
+ task_train(config, experiment_name=experiment_name)
# function to run worklflow by config
diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py
index 1d0811d16..ec76343bd 100644
--- a/qlib/workflow/record_temp.py
+++ b/qlib/workflow/record_temp.py
@@ -15,6 +15,7 @@ from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
+from ..contrib.strategy.strategy import BaseStrategy
logger = get_module_logger("workflow", "INFO")
@@ -220,7 +221,7 @@ class PortAnaRecord(SignalRecord):
self.strategy_config = config["strategy"]
self.backtest_config = config["backtest"]
- self.strategy = init_instance_by_config(self.strategy_config)
+ self.strategy = init_instance_by_config(self.strategy_config, accept_types=BaseStrategy)
def generate(self, **kwargs):
# check previously stored prediction results
diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py
index 9f6dd88e2..bdc227029 100644
--- a/scripts/dump_bin.py
+++ b/scripts/dump_bin.py
@@ -333,7 +333,9 @@ class DumpDataFix(DumpDataAll):
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
p_bar.update()
- self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index"))
+ _inst_df = pd.DataFrame.from_dict(self._old_instruments, orient="index")
+ _inst_df.index.names = [self.symbol_field_name]
+ self.save_instruments(_inst_df.reset_index())
logger.info("end of instruments dump.\n")
def dump(self):
diff --git a/setup.py b/setup.py
index 4fe410b9d..3438781b2 100644
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@ REQUIRED = [
"schedule>=0.6.0",
"cvxpy==1.0.21",
"hyperopt==0.1.1",
- "fire>=0.2.1",
+ "fire>=0.3.1",
"statsmodels",
"xlrd>=1.0.0",
"plotly==4.12.0",
@@ -58,7 +58,6 @@ REQUIRED = [
"joblib>=0.17.0",
"fire>=0.3.1",
"ruamel.yaml>=0.16.12",
- "pytorch-tabnet>=2.0.1",
]
# Numpy include