Mirror of https://github.com/microsoft/qlib.git (synced 2026-06-17 11:18:24 +08:00)
Compare commits: neutrader...backtest_i (2 commits)
| Author | SHA1 | Date | |
|---|---|---|---|
| | 5ac9dd7221 | | |
| | 7efec6bbc4 | | |
64
README.md
@@ -11,11 +11,6 @@
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Release Qlib v0.8.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.8.0) on Dec 8, 2021 |
|
||||
| ADD model | [Released](https://github.com/microsoft/qlib/pull/704) on Nov 22, 2021 |
|
||||
| ADARNN model | [Released](https://github.com/microsoft/qlib/pull/689) on Nov 14, 2021 |
|
||||
| TCN model | [Released](https://github.com/microsoft/qlib/pull/668) on Nov 4, 2021 |
|
||||
| Nested Decision Framework | [Released](https://github.com/microsoft/qlib/pull/438) on Oct 1, 2021. [Example](https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py) and [Doc](https://qlib.readthedocs.io/en/latest/component/highfreq.html) |
|
||||
|Temporal Routing Adaptor (TRA) | [Released](https://github.com/microsoft/qlib/pull/531) on July 30, 2021 |
|
||||
| Transformer & Localformer | [Released](https://github.com/microsoft/qlib/pull/508) on July 22, 2021 |
|
||||
| Release Qlib v0.7.0 | [Released](https://github.com/microsoft/qlib/releases/tag/v0.7.0) on July 12, 2021 |
|
||||
@@ -69,6 +64,7 @@ Your feedbacks about the features are very important.
|
||||
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
|
||||
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
|
||||
| Meta-Learning-based data selection | Initial opensource version under development |
|
||||
|
||||
# Framework of Qlib
|
||||
@@ -83,7 +79,7 @@ At the module level, Qlib is a platform that consists of the above components. T
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Decision Generator` will generate the target trading decisions(i.e. portfolio, orders) to be executed by `Execution Env` (i.e. the trading market). There may be multiple levels of `Trading Agent` and `Execution Env` (e.g. an _order executor trading agent and intraday order execution environment_ could behave like an interday trading environment and nested in _daily portfolio management trading agent and interday trading environment_ ) |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
* The modules with hand-drawn style are under development and will be released in the future.
|
||||
@@ -160,17 +156,15 @@ Load and prepare data by running the following code:
|
||||
|
||||
This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in
|
||||
the same repository.
|
||||
Users could create the same dataset with it. [Description of dataset](https://github.com/microsoft/qlib/tree/main/scripts/data_collector#description-of-dataset)
|
||||
Users could create the same dataset with it.
|
||||
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
### Automatic update of daily frequency data (from yahoo finance)
|
||||
> This step is *Optional* if users only want to try their models and strategies on history data.
|
||||
>
|
||||
> It is recommended that users update the data manually once (--trading_date 2021-05-25) and then set it to update automatically.
|
||||
>
|
||||
> For more information, please refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
> For more information refer to: [yahoo collector](https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance)
|
||||
|
||||
* Automatic update of data to the "qlib" directory each trading day(Linux)
|
||||
* use *crontab*: `crontab -e`
|
||||
@@ -284,25 +278,22 @@ The automatic workflow may not suit the research workflow of all Quant researche
|
||||
# [Quant Model (Paper) Zoo](examples/benchmarks)
|
||||
|
||||
Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](examples/benchmarks/XGBoost/)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](examples/benchmarks/LightGBM/)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](examples/benchmarks/CatBoost/)
|
||||
- [MLP based on pytorch](examples/benchmarks/MLP/)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural computation 1997)](examples/benchmarks/LSTM/)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](examples/benchmarks/GRU/)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](examples/benchmarks/ALSTM)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](examples/benchmarks/GATs/)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](examples/benchmarks/SFM/)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](examples/benchmarks/TabNet/)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](examples/benchmarks/DoubleEnsemble/)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](examples/benchmarks/TCTS/)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](examples/benchmarks/Transformer/)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](examples/benchmarks/Localformer/)
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](examples/benchmarks/TRA/)
|
||||
- [TCN based on pytorch (Shaojie Bai, et al. 2018)](examples/benchmarks/TCN/)
|
||||
- [ADARNN based on pytorch (YunTao Du, et al. 2021)](examples/benchmarks/ADARNN/)
|
||||
- [ADD based on pytorch (Hongshun Tang, et al.2020)](examples/benchmarks/ADD/)
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al. KDD 2016)](qlib/contrib/model/xgboost.py)
|
||||
- [GBDT based on LightGBM (Guolin Ke, et al. NIPS 2017)](qlib/contrib/model/gbdt.py)
|
||||
- [GBDT based on Catboost (Liudmila Prokhorenkova, et al. NIPS 2018)](qlib/contrib/model/catboost_model.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al. Neural omputation 1997)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al. 2014)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al. IJCAI 2017)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al. 2017)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. KDD 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. International Journal of Forecasting 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. AAAI 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. ICDM 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
- [TCTS based on pytorch (Xueqing Wu, et al. ICML 2021)](qlib/contrib/model/pytorch_tcts.py)
|
||||
- [Transformer based on pytorch (Ashish Vaswani, et al. NeurIPS 2017)](qlib/contrib/model/pytorch_transformer.py)
|
||||
- [Localformer based on pytorch (Juyong Jiang, et al.)](qlib/contrib/model/pytorch_localformer.py)
|
||||
- [TRA based on pytorch (Hengxu, Dong, et al. KDD 2021)](qlib/contrib/model/pytorch_tra.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -400,26 +391,17 @@ Join IM discussion groups:
|
||||
||
|
||||
|
||||
# Contributing
|
||||
We appreciate all contributions and thank all the contributors!
|
||||
<a href="https://github.com/microsoft/qlib/graphs/contributors"><img src="https://contrib.rocks/image?repo=microsoft/qlib" /></a>
|
||||
|
||||
Before we released Qlib as an open-source project on GitHub in Sep 2020, Qlib was an internal project in our group. Unfortunately, the internal commit history was not kept. Many members of our group also contributed a lot to Qlib, including Ruihua Wang, Yinda Zhang, Haisu Yu, Shuyu Wang, Bochen Pang, and [Dong Zhou](https://github.com/evanzd/evanzd). Special thanks to [Dong Zhou](https://github.com/evanzd/evanzd) for his initial version of Qlib.
|
||||
|
||||
## Guidance
|
||||
|
||||
This project welcomes contributions and suggestions.
|
||||
**Here are some
|
||||
[code standards](docs/developer/code_standard.rst) for submiting a pull request.**
|
||||
[code standards](docs/developer/code_standard.rst) when you submit a pull request.**
|
||||
|
||||
Making contributions is not a hard thing. Solving an issue(maybe just answering a question raised in [issues list](https://github.com/microsoft/qlib/issues) or [gitter](https://gitter.im/Microsoft/qlib)), fixing/issuing a bug, improving the documents and even fixing a typo are important contributions to Qlib.
|
||||
|
||||
For example, if you want to contribute to Qlib's document/code, you can follow the steps in the figure below.
|
||||
If you want to contribute to Qlib's document, you can follow the steps in the figure below.
|
||||
<p align="center">
|
||||
<img src="https://github.com/demon143/qlib/blob/main/docs/_static/img/change%20doc.gif" />
|
||||
</p>
|
||||
|
||||
|
||||
## Licence
|
||||
Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the right to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.8.0.99
|
||||
0.7.2.99
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 198 KiB |
@@ -11,10 +11,7 @@ Introduction
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.The processes of task generation, model training and combine and collect data are shown in the following figure.
|
||||
|
||||
.. image:: ../_static/img/Task-Gen-Recorder-Collector.svg
|
||||
:align: center
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
@@ -77,8 +74,6 @@ If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
Before collecting model training results, you need to use the ``qlib.init`` to specify the path of mlruns.
|
||||
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
|
||||
@@ -87,10 +82,8 @@ To collect the results of ``task`` after training, ``Qlib`` provides `Collector
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object.
|
||||
You can set the ensembles you want in the ``Collector``'s process_list.
|
||||
Common ensembles include ``AverageEnsemble`` and ``RollingEnsemble``. Average ensemble is used to ensemble the results of different models in the same time period. Rollingensemble is used to ensemble the results of different models in the same time period
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
|
||||
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
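
The ``group`` and ``reduce`` steps described above can be illustrated with plain dictionaries. This is a minimal sketch, not Qlib's ``Group`` or ``Ensemble`` implementation: the helper names ``group`` and ``average_reduce`` are hypothetical, and the collected objects are assumed to be plain numbers so that averaging is well defined.

```python
from collections import defaultdict


def group(collected, group_key, inner_key):
    """Turn {(A, B, C): obj} into {(A, B): {C: obj}}."""
    grouped = defaultdict(dict)
    for key, obj in collected.items():
        grouped[group_key(key)][inner_key(key)] = obj
    return dict(grouped)


def average_reduce(inner):
    """Reduce {C1: obj, C2: obj} to a single object by averaging (hypothetical ensemble)."""
    return sum(inner.values()) / len(inner)


# Hypothetical collected results keyed by (A, B, C).
collected = {
    ("A", "B", "C1"): 1.0,
    ("A", "B", "C2"): 3.0,
    ("A", "D", "C1"): 5.0,
}

# Step 1: group by the first two key elements, keep the third as the inner key.
grouped = group(collected, group_key=lambda k: k[:2], inner_key=lambda k: k[2])

# Step 2: reduce each inner dict to one object.
reduced = {key: average_reduce(inner) for key, inner in grouped.items()}

print(grouped)
print(reduced)
```

In this sketch the reduce step plays the role that an ensemble plays in the hierarchy above: it receives the inner dict produced by grouping and returns one merged object.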
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
114
docs/component/backtest.rst
Normal file
@@ -0,0 +1,114 @@
|
||||
.. _backtest:
|
||||
|
||||
============================================
|
||||
Intraday Trading: Model&Strategy Testing
|
||||
============================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
``Intraday Trading`` is designed to test models and strategies, helping users check the performance of a custom model or strategy.
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
    ``Intraday Trading`` uses ``Order Executor`` to trade and execute orders output by ``Portfolio Strategy``. ``Order Executor`` is a component in `Qlib Framework <../introduction/introduction.html#framework>`_ that executes orders. ``VWAP Executor`` and ``Close Executor`` are currently supported by ``Qlib``. In the future, ``Qlib`` will also support ``HighFreq Executor``.
|
||||
|
||||
|
||||
|
||||
Example
|
||||
===========================
|
||||
|
||||
Users need to generate a `prediction score` (a pandas DataFrame) with a MultiIndex <datetime, instrument> and a `score` column. Users also need to specify the strategy used in the backtest. If no strategy is specified,
a `TopkDropoutStrategy` with `(topk=50, n_drop=5, risk_degree=0.95, limit_threshold=0.0095)` will be used.
If the ``Strategy`` module is not the part users are interested in, `TopkDropoutStrategy` is enough.
|
||||
|
||||
The simple example of the default strategy is as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.contrib.evaluate import backtest
|
||||
# pred_score is the prediction score
|
||||
report, positions = backtest(pred_score, topk=50, n_drop=5, limit_threshold=0.0095)
|
||||
|
||||
To know more about backtesting with a specific ``Strategy``, please refer to `Portfolio Strategy <strategy.html>`_.
|
||||
|
||||
To know more about the prediction score `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
Prediction Score
|
||||
-----------------
|
||||
|
||||
The `prediction score` is a pandas DataFrame. Its index is <datetime(pd.Timestamp), instrument(str)> and it must
contain a `score` column.
|
||||
|
||||
A prediction sample is shown as follows.
|
||||
|
||||
.. code-block:: python

      datetime instrument     score
    2019-01-04   SH600000 -0.505488
    2019-01-04   SZ002531 -0.320391
    2019-01-04   SZ000999  0.583808
    2019-01-04   SZ300569  0.819628
    2019-01-04   SZ001696 -0.137140
                 ...            ...
    2019-04-30   SZ000996 -1.027618
    2019-04-30   SH603127  0.225677
    2019-04-30   SH603126  0.462443
    2019-04-30   SH603133 -0.302460
    2019-04-30   SZ300760 -0.126383
|
||||
|
||||
``Forecast Model`` module can make predictions, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
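
For readers who want to try the backtest without training a model first, a DataFrame of the shape described above can be built by hand. This is a minimal sketch using pandas only; the three rows are copied from the sample above for illustration and are not real model output.

```python
import pandas as pd

# Illustrative rows only; real scores would come from a forecast model.
records = [
    ("2019-01-04", "SH600000", -0.505488),
    ("2019-01-04", "SZ002531", -0.320391),
    ("2019-04-30", "SZ300760", -0.126383),
]

pred_score = pd.DataFrame(records, columns=["datetime", "instrument", "score"])
# Convert the date strings to pd.Timestamp, as the index description requires.
pred_score["datetime"] = pd.to_datetime(pred_score["datetime"])
# Index order <datetime, instrument>, matching the sample above.
pred_score = pred_score.set_index(["datetime", "instrument"]).sort_index()

print(pred_score)
```

The only structural requirements taken from the text are the two index levels and the ``score`` column; everything else in the sketch is arbitrary.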
|
||||
|
||||
Backtest Result
|
||||
------------------
|
||||
|
||||
The backtest results are in the following form:
|
||||
|
||||
.. code-block:: python

                                                      risk
    excess_return_without_cost mean               0.000605
                               std                0.005481
                               annualized_return  0.152373
                               information_ratio  1.751319
                               max_drawdown      -0.059055
    excess_return_with_cost    mean               0.000410
                               std                0.005478
                               annualized_return  0.103265
                               information_ratio  1.187411
                               max_drawdown      -0.075024
|
||||
|
||||
|
||||
|
||||
- `excess_return_without_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) without cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` without cost. Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) without cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
|
||||
|
||||
- `excess_return_with_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) series with cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) series with cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) with cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` with cost. Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) with cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
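
The five metrics listed above can be sketched from a series of per-period excess returns. This is an illustrative calculation under stated assumptions, not necessarily the formula Qlib uses: returns are accumulated by summation rather than compounding, there are 252 trading periods per year, and the standard deviation is the sample standard deviation. The function name ``risk_summary`` is hypothetical.

```python
import math
import statistics
from itertools import accumulate


def risk_summary(excess_returns, periods_per_year=252):
    """Summarize per-period excess returns (assumed additive accumulation)."""
    mean = statistics.mean(excess_returns)
    std = statistics.stdev(excess_returns)  # sample standard deviation
    # Cumulative excess return and its running maximum.
    cumulative = list(accumulate(excess_returns))
    running_max = list(accumulate(cumulative, max))
    # Drawdown is the gap between the cumulative return and its previous peak.
    max_drawdown = min(c - m for c, m in zip(cumulative, running_max))
    return {
        "mean": mean,
        "std": std,
        "annualized_return": mean * periods_per_year,
        "information_ratio": mean / std * math.sqrt(periods_per_year),
        "max_drawdown": max_drawdown,
    }


# Hypothetical daily excess returns.
excess = [0.01, -0.02, 0.015, -0.005, 0.01]
summary = risk_summary(excess)
for name, value in summary.items():
    print(f"{name:>18}: {value: .6f}")
```

Running the same function once on returns before costs and once on returns after costs gives the two blocks of the result table.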
|
||||
|
||||
|
||||
|
||||
Reference
|
||||
==============
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading <../reference/api.html#module-qlib.contrib.evaluate>`_.
|
||||
@@ -338,7 +338,7 @@ DataHandlerLP
|
||||
|
||||
In addition to use ``Data Handler`` in an automatic workflow with ``qrun``, ``Data Handler`` can be used as an independent module, by which users can easily preprocess data (standardization, remove NaN, etc.) and build datasets.
|
||||
|
||||
In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some learnable ``Processors`` which can learn the parameters of data processing(e.g., parameters for zscore normalization). When new data comes in, these `trained` ``Processors`` can then process the new data and thus processing real-time data in an efficient way becomes possible. More information about ``Processors`` will be listed in the next subsection.
|
||||
In order to achieve so, ``Qlib`` provides a base class `qlib.data.dataset.DataHandlerLP <../reference/api.html#qlib.data.dataset.handler.DataHandlerLP>`_. The core idea of this class is that: we will have some leanable ``Processors`` which can learn the parameters of data processing(e.g., parameters for zscore normalization). When new data comes in, these `trained` ``Processors`` can then process the new data and thus processing real-time data in an efficient way becomes possible. More information about ``Processors`` will be listed in the next subsection.
|
||||
|
||||
|
||||
Interface
|
||||
|
||||
@@ -1,31 +1,120 @@
|
||||
.. _highfreq:
|
||||
|
||||
============================================
|
||||
Design of Nested Decision Execution Framework for High-Frequency Trading
|
||||
Design of hierarchical order execution framework
|
||||
============================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
Introduction
|
||||
===================
|
||||
|
||||
Daily trading (e.g. portfolio management) and intraday trading (e.g. orders execution) are two hot topics in Quant investment and usually studied separately.
|
||||
|
||||
To get the joint trading performance of daily and intraday trading, they must interact with each other and be backtested jointly.
In order to support joint backtesting of strategies at multiple levels, a corresponding framework is required. None of the publicly available high-frequency trading frameworks considers multi-level joint trading, which makes the aforementioned backtesting inaccurate.

Besides backtesting, the optimization of strategies at different levels is not standalone: strategies can affect each other.
For example, the best portfolio management strategy may change with the performance of order execution (e.g. a portfolio with higher turnover may become a better choice when we improve the order execution strategies).
To achieve good overall performance, it is necessary to consider the interaction of strategies at different levels.

Therefore, building a new framework for trading at multiple levels becomes necessary to solve the problems mentioned above. For this purpose we designed a nested decision execution framework that considers the interaction of strategies.
|
||||
In order to support reinforcement learning algorithms for high-frequency trading, a corresponding framework is required. None of the publicly available high-frequency trading frameworks now consider multi-layer trading mechanisms, and the currently designed algorithms cannot directly use existing frameworks.
|
||||
In addition to supporting the basic intraday multi-layer trading, the linkage with the day-ahead strategy is also a factor that affects the performance evaluation of the strategy. Different day strategies generate different order distributions and different patterns on different stocks. To verify that high-frequency trading strategies perform well on real trading orders, it is necessary to support day-frequency and high-frequency multi-level linkage trading. In addition to more accurate backtesting of high-frequency trading algorithms, if the distribution of day-frequency orders is considered when training a high-frequency trading model, the algorithm can also be optimized more for product-specific day-frequency orders.
|
||||
Therefore, innovation in the high-frequency trading framework is necessary to solve the various problems mentioned above, for which we designed a hierarchical order execution framework that can link daily-frequency and intra-day trading at different granularities.
|
||||
|
||||
.. image:: ../_static/img/framework.svg
|
||||
|
||||
The design of the framework is shown in the yellow part in the middle of the figure above. Each level consists of ``Trading Agent`` and ``Execution Env``. ``Trading Agent`` has its own data processing module (``Information Extractor``), forecasting module (``Forecast Model``) and decision generator (``Decision Generator``). The trading algorithm generates the decisions by the ``Decision Generator`` based on the forecast signals output by the ``Forecast Module``, and the decisions generated by the trading algorithm are passed to the ``Execution Env``, which returns the execution results.
|
||||
The design of the framework is shown in the figure above. At each layer consists of Trading Agent and Execution Env. The Trading Agent has its own data processing module (Information Extractor), forecasting module (Forecast Model) and decision generator (Decision Generator). The trading algorithm generates the corresponding decisions by the Decision Generator based on the forecast signals output by the Forecast Module, and the decisions generated by the trading algorithm are passed to the Execution Env, which returns the execution results. Here the frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intra-day trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The hierarchical order execution framework is user-defined in terms of hierarchy division and decision frequency, making it easy for users to explore the effects of combining different levels of trading algorithms and breaking down the barriers between different levels of trading algorithm optimization.
|
||||
In addition to the innovation in the framework, the hierarchical order execution framework also takes into account various details of the real backtesting environment, minimizing the differences with the final real environment as much as possible. At the same time, the framework is designed to unify the interface between online and offline (e.g. data pre-processing level supports using the same set of code to process both offline and online data) to reduce the cost of strategy go-live as much as possible.
|
||||
|
||||
Prepare Data
|
||||
===================
|
||||
.. _data:: ../../examples/highfreq/README.md
|
||||
|
||||
The frequency of trading algorithm, decision content and execution environment can be customized by users (e.g. intraday trading, daily-frequency trading, weekly-frequency trading), and the execution environment can be nested with finer-grained trading algorithm and execution environment inside (i.e. sub-workflow in the figure, e.g. daily-frequency orders can be turned into finer-grained decisions by splitting orders within the day). The flexibility of nested decision execution framework makes it easy for users to explore the effects of combining different levels of trading strategies and break down the optimization barriers between different levels of trading algorithm.
|
||||
|
||||
Example
|
||||
===========================
|
||||
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
Here is an example of highfreq execution.
|
||||
|
||||
.. code-block:: python

    import qlib
    from qlib.data import D

    # init qlib
    provider_uri_day = "~/.qlib/qlib_data/cn_data"
    provider_uri_1min = "~/.qlib/qlib_data/cn_data_1min"
    provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
    qlib.init(provider_uri=provider_uri_day, expression_cache=None, dataset_cache=None)

    # data freq and backtest time
    freq = "1min"
    inst_list = D.list_instruments(D.instruments("all"), as_list=True)
    start_time = "2020-01-01"
    end_time = "2020-01-31"
||||
|
||||
When initializing qlib, if the default data is used, then both daily and minute frequency data need to be passed in.
|
||||
|
||||
.. code-block:: python

    # random order strategy config
    strategy_config = {
        "class": "RandomOrderStrategy",
        "module_path": "qlib.contrib.strategy.rule_strategy",
        "kwargs": {
            "trade_range": TradeRangeByTime("9:30", "15:00"),
            "sample_ratio": 1.0,
            "volume_ratio": 0.01,
            "market": market,
        },
    }
|
||||
|
||||
.. code-block:: python

    # backtest config
    backtest_config = {
        "start_time": start_time,
        "end_time": end_time,
        "account": 100000000,
        "benchmark": None,
        "exchange_kwargs": {
            "freq": freq,
            "limit_threshold": 0.095,
            "deal_price": "close",
            "open_cost": 0.0005,
            "close_cost": 0.0015,
            "min_cost": 5,
            "codes": market,
        },
        "pos_type": "InfPosition",  # Position with infinite position
    }
|
||||
|
||||
For details, please refer to ``../../qlib/backtest``.
|
||||
|
||||
.. code-block:: python

    # executor config
    executor_config = {
        "class": "NestedExecutor",
        "module_path": "qlib.backtest.executor",
        "kwargs": {
            "time_per_step": "day",
            "inner_executor": {
                "class": "SimulatorExecutor",
                "module_path": "qlib.backtest.executor",
                "kwargs": {
                    "time_per_step": freq,
                    "generate_portfolio_metrics": True,
                    "verbose": False,
                    # "verbose": True,
                    "indicator_config": {
                        "show_indicator": False,
                    },
                },
            },
            "inner_strategy": {
                "class": "TWAPStrategy",
                "module_path": "qlib.contrib.strategy.rule_strategy",
            },
            "track_data": True,
            "generate_portfolio_metrics": True,
            "indicator_config": {
                "show_indicator": True,
            },
        },
    }
|
||||
|
||||
``NestedExecutor`` is used for any layer that is not the innermost one, so its initialization parameters should contain ``inner_executor`` and ``inner_strategy``. ``SimulatorExecutor`` indicates that the current executor is the innermost layer. The innermost strategy used here is the TWAP strategy; the framework currently also supports the VWAP strategy.
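
To give an intuition for what a TWAP-style inner strategy does with a daily order, the sketch below splits one order amount evenly across the intraday steps. It is a simplified illustration and not the code of Qlib's ``TWAPStrategy``: the function name ``twap_slices`` is hypothetical, and trade units, price limits and suspended instruments are ignored.

```python
def twap_slices(total_amount, n_steps):
    """Split an integer order amount into n_steps near-equal slices."""
    if n_steps <= 0:
        raise ValueError("n_steps must be positive")
    base, remainder = divmod(total_amount, n_steps)
    # The first `remainder` steps carry one extra unit so the slices sum to the total.
    return [base + (1 if i < remainder else 0) for i in range(n_steps)]


# A hypothetical daily order of 1000 shares executed over 240 one-minute steps.
slices = twap_slices(1000, 240)
print(slices[:3], slices[-3:], sum(slices))
```

In the nested setting, the outer (daily) layer would produce the total amount and the inner (1min) layer would execute one slice per step.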
|
||||
|
||||
.. code-block:: python

    # backtest
    portfolio_metrics_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
|
||||
The metrics of backtest are included in the portfolio_metrics_dict and indicator_dict.
|
||||
|
||||
@@ -12,9 +12,7 @@ Introduction
|
||||
|
||||
Because the components in ``Qlib`` are designed in a loosely-coupled way, ``Portfolio Strategy`` can be used as an independent module also.
|
||||
|
||||
``Qlib`` provides several implemented portfolio strategies. Also, ``Qlib`` supports custom strategy, users can customize strategies according to their own requirements.
|
||||
|
||||
After users specify the models (forecasting signals) and strategies, running a backtest helps them check the performance of a custom model (forecasting signals) or strategy.
|
||||
``Qlib`` provides several implemented portfolio strategies. Also, ``Qlib`` supports custom strategy, users can customize strategies according to their own needs.
|
||||
|
||||
Base Class & Interface
|
||||
======================
|
||||
@@ -84,206 +82,38 @@ TopkDropoutStrategy
|
||||
|
||||
Usage & Example
|
||||
====================
|
||||
|
||||
First, users can create a model to get trading signals (the variable name is ``pred_score`` in the following cases).
|
||||
|
||||
Prediction Score
|
||||
-----------------
|
||||
|
||||
The `prediction score` is a pandas DataFrame. Its index is <datetime(pd.Timestamp), instrument(str)> and it must
|
||||
contain a `score` column.
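
The layout described above can be sketched with plain pandas. This is only an illustration: the three rows reuse values from the prediction sample in this document, and in practice ``pred_score`` would come from a trained model rather than being built by hand.

```python
import pandas as pd

# Illustrative rows only; real scores are produced by a forecast model.
index = pd.MultiIndex.from_tuples(
    [
        (pd.Timestamp("2019-01-04"), "SH600000"),
        (pd.Timestamp("2019-01-04"), "SZ002531"),
        (pd.Timestamp("2019-04-30"), "SH603127"),
    ],
    names=["datetime", "instrument"],
)
pred_score = pd.DataFrame({"score": [-0.505488, -0.320391, 0.225677]}, index=index)

# Check the layout: a <datetime, instrument> index and a `score` column.
assert list(pred_score.index.names) == ["datetime", "instrument"]
assert "score" in pred_score.columns
print(pred_score)
```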
|
||||
|
||||
A prediction sample is shown as follows.
|
||||
``Portfolio Strategy`` can be specified in ``Intraday Trading (Backtest)``; an example is as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
datetime instrument score
|
||||
2019-01-04 SH600000 -0.505488
|
||||
2019-01-04 SZ002531 -0.320391
|
||||
2019-01-04 SZ000999 0.583808
|
||||
2019-01-04 SZ300569 0.819628
|
||||
2019-01-04 SZ001696 -0.137140
|
||||
... ...
|
||||
2019-04-30 SZ000996 -1.027618
|
||||
2019-04-30 SH603127 0.225677
|
||||
2019-04-30 SH603126 0.462443
|
||||
2019-04-30 SH603133 -0.302460
|
||||
2019-04-30 SZ300760 -0.126383
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import backtest
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
}
|
||||
BACKTEST_CONFIG = {
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": BENCHMARK,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
|
||||
}
|
||||
# use default strategy
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
|
||||
``Forecast Model`` module can make predictions, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
# pred_score is the `prediction score` output by Model
|
||||
report_normal, positions_normal = backtest(
|
||||
pred_score, strategy=strategy, **BACKTEST_CONFIG
|
||||
)
|
||||
|
||||
Normally, the prediction score is the output of the models. But some models are learned from a label with a different scale, so the scale of the prediction score may differ from your expectation (e.g. the return of instruments).
|
||||
|
||||
Qlib does not add a step to rescale the prediction score to a unified scale, because not every trading strategy cares about the scale (e.g. TopkDropoutStrategy only cares about the order). The strategy is therefore responsible for rescaling the prediction score (e.g. some portfolio-optimization-based strategies may require a meaningful scale).
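
The point about order versus scale can be illustrated with a small sketch. The ``top_k`` helper below is hypothetical and is not how ``TopkDropoutStrategy`` is implemented; it only shows that a selection based purely on ranking is unchanged when the scores are rescaled by a positive factor and shifted.

```python
import pandas as pd


def top_k(scores: pd.Series, k: int) -> list:
    """Return the k instruments with the highest score (hypothetical helper)."""
    return scores.sort_values(ascending=False).head(k).index.tolist()


# Made-up scores for four instruments on a single day.
scores = pd.Series(
    {"SH600000": -0.51, "SZ002531": -0.32, "SZ000999": 0.58, "SZ300569": 0.82}
)

# A positive rescaling plus a shift changes the scale but not the ordering.
rescaled = scores * 0.01 + 0.002

assert top_k(scores, 2) == top_k(rescaled, 2)
print(top_k(scores, 2))
```

A strategy that turns scores into portfolio weights through an optimizer would not have this invariance, which is why such a strategy has to rescale the scores itself.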
|
||||
|
||||
Running backtest
|
||||
-----------------
|
||||
|
||||
- In most cases, users can backtest their portfolio management strategy with ``backtest_daily``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.contrib.evaluate import backtest_daily
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
# init qlib
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
# pred_score, pd.Series
|
||||
"signal": pred_score,
|
||||
}
|
||||
|
||||
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
report_normal, positions_normal = backtest_daily(
|
||||
start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_obj
|
||||
)
|
||||
# backtest_daily runs at daily frequency, so analyze at that frequency
analysis_freq = "day"
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
pprint(analysis_df)
|
||||
|
||||
|
||||
|
||||
- If users would like to control their strategies in more detail (e.g. they have a more advanced version of the executor), they can follow this example.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.backtest import backtest, executor
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
# init qlib
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
FREQ = "day"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
# pred_score, pd.Series
|
||||
"signal": pred_score,
|
||||
}
|
||||
|
||||
EXECUTOR_CONFIG = {
|
||||
"time_per_step": "day",
|
||||
"generate_portfolio_metrics": True,
|
||||
}
|
||||
|
||||
backtest_config = {
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"account": 100000000,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"exchange_kwargs": {
|
||||
"freq": FREQ,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
# strategy object
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
# executor object
|
||||
executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)
|
||||
# backtest
|
||||
portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)
|
||||
analysis_freq = "{0}{1}".format(*Freq.parse(FREQ))
|
||||
# backtest info
|
||||
report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=analysis_freq
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
# print out results
|
||||
pprint(f"The following are analysis results of benchmark return({analysis_freq}).")
|
||||
pprint(risk_analysis(report_normal["bench"], freq=analysis_freq))
|
||||
pprint(f"The following are analysis results of the excess return without cost({analysis_freq}).")
|
||||
pprint(analysis["excess_return_without_cost"])
|
||||
pprint(f"The following are analysis results of the excess return with cost({analysis_freq}).")
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
|
||||
|
||||
Result
|
||||
------------------
|
||||
|
||||
The backtest results are in the following form:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
risk
|
||||
excess_return_without_cost mean 0.000605
|
||||
std 0.005481
|
||||
annualized_return 0.152373
|
||||
information_ratio 1.751319
|
||||
max_drawdown -0.059055
|
||||
excess_return_with_cost mean 0.000410
|
||||
std 0.005478
|
||||
annualized_return 0.103265
|
||||
information_ratio 1.187411
|
||||
max_drawdown -0.075024
|
||||
|
||||
|
||||
- `excess_return_without_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) without cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) without cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` without cost. Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) without cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
|
||||
|
||||
- `excess_return_with_cost`
|
||||
- `mean`
|
||||
Mean value of the `CAR` (cumulative abnormal return) series with cost
|
||||
- `std`
|
||||
The `Standard Deviation` of `CAR` (cumulative abnormal return) series with cost.
|
||||
- `annualized_return`
|
||||
The `Annualized Rate` of `CAR` (cumulative abnormal return) with cost.
|
||||
- `information_ratio`
|
||||
The `Information Ratio` with cost. Please refer to `Information Ratio – IR <https://www.investopedia.com/terms/i/informationratio.asp>`_.
|
||||
- `max_drawdown`
|
||||
The `Maximum Drawdown` of `CAR` (cumulative abnormal return) with cost, please refer to `Maximum Drawdown (MDD) <https://www.investopedia.com/terms/m/maximum-drawdown-mdd.asp>`_.
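
As a rough guide to how the five metrics listed above relate to a return series, here is a minimal sketch. It assumes returns are accumulated by summation and annualized with 252 periods per year; ``risk_summary`` is a hypothetical helper, and whether ``risk_analysis`` follows exactly these conventions should be checked against its own documentation.

```python
import numpy as np
import pandas as pd


def risk_summary(r: pd.Series, periods_per_year: int = 252) -> pd.Series:
    """Summarize a per-period (excess) return series; illustrative conventions only."""
    mean = r.mean()
    std = r.std(ddof=1)
    cum = r.cumsum()  # assumption: returns accumulate additively
    return pd.Series(
        {
            "mean": mean,
            "std": std,
            "annualized_return": mean * periods_per_year,
            "information_ratio": mean / std * np.sqrt(periods_per_year),
            # largest drop of the cumulative curve below its running maximum
            "max_drawdown": (cum - cum.cummax()).min(),
        }
    )


# Made-up daily excess returns.
excess = pd.Series([0.01, -0.02, 0.015, -0.005, 0.01])
summary = risk_summary(excess)
print(summary)
```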
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
|
||||
Reference
|
||||
===================
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
To know more about ``Portfolio Strategy``, please refer to `Strategy API <../reference/api.html#module-qlib.contrib.strategy.strategy>`_.
|
||||
|
||||
@@ -38,8 +38,8 @@ Document Structure
|
||||
Workflow: Workflow Management <component/workflow.rst>
|
||||
Data Layer: Data Framework&Usage <component/data.rst>
|
||||
Forecast Model: Model Training & Prediction <component/model.rst>
|
||||
Portfolio Management and Backtest <component/strategy.rst>
|
||||
Nested Decision Execution: High-Frequency Trading <component/highfreq.rst>
|
||||
Strategy: Portfolio Management <component/strategy.rst>
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
@@ -34,14 +34,9 @@ Name Description
|
||||
|
||||
`Workflow` layer `Workflow` layer covers the whole workflow of quantitative investment.
|
||||
`Information Extractor` extracts data for models. `Forecast Model` focuses
|
||||
on producing all kinds of forecast signals (e.g. *alpha*, risk) for other
|
||||
modules. With these signals `Decision Generator` will generate the target
|
||||
trading decisions(i.e. portfolio, orders) to be executed by `Execution Env`
|
||||
(i.e. the trading market). There may be multiple levels of `Trading Agent`
|
||||
and `Execution Env` (e.g. an *order executor trading agent and intraday
|
||||
order execution environment* could behave like an interday trading
|
||||
environment and nested in *daily portfolio management trading agent and
|
||||
interday trading environment* )
|
||||
on producing all kinds of forecast signals (e.g. _alpha_, risk) for other
|
||||
modules. With these signals `Portfolio Generator` will generate the target
|
||||
portfolio and produce orders to be executed by `Order Executor`.
|
||||
|
||||
`Interface` layer `Interface` layer tries to present a user-friendly interface for the underlying
|
||||
system. `Analyser` module will provide users detailed analysis reports of
|
||||
|
||||
@@ -48,7 +48,6 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
- ``qlib.config.REG_CN``: China stock market.
|
||||
|
||||
Different modes will result in different trading limitations and costs.
|
||||
The region is just a `shortcut for defining a batch of configurations <https://github.com/microsoft/qlib/blob/main/qlib/config.py#L239>`_. Users can set the key configurations manually if the existing region setting can't meet their requirements.
|
||||
- `redis_host`
|
||||
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
|
||||
The lock and cache mechanism relies on redis.
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# AdaRNN
|
||||
* Code: [https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn](https://github.com/jindongwang/transferlearning/tree/master/code/deep/adarnn)
|
||||
* Paper: [AdaRNN: Adaptive Learning and Forecasting for Time Series](https://arxiv.org/pdf/2108.04443.pdf).
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,88 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ADARNN
|
||||
module_path: qlib.contrib.model.pytorch_adarnn
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,3 +0,0 @@
|
||||
# ADD
|
||||
* Paper: [ADD: Augmented Disentanglement Distillation Framework for Improving Stock Trend Forecasting](https://arxiv.org/abs/2012.06289).
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,94 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: ADD
|
||||
module_path: qlib.contrib.model.pytorch_add
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.1
|
||||
dec_dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 5000
|
||||
metric: ic
|
||||
base_model: GRU
|
||||
gamma: 0.1
|
||||
gamma_clip: 0.2
|
||||
optimizer: adam
|
||||
mu: 0.2
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,2 +0,0 @@
|
||||
# Gated Recurrent Unit (GRU)
|
||||
* Paper: [Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation](https://aclanthology.org/D14-1179.pdf).
|
||||
@@ -1,2 +0,0 @@
|
||||
# Long Short-Term Memory (LSTM)
|
||||
* Paper: [Long Short-Term Memory](https://direct.mit.edu/neco/article-abstract/9/8/1735/6109/Long-Short-Term-Memory?redirectedFrom=fulltext).
|
||||
@@ -1 +0,0 @@
|
||||
# Localformer
|
||||
@@ -1 +0,0 @@
|
||||
# Multi-Layer Perceptron (MLP)
|
||||
@@ -1,9 +1,4 @@
|
||||
# Benchmarks Performance
|
||||
This page lists a batch of methods designed for alpha seeking. Each method tries to give scores/predictions for all stocks each day (e.g. forecasting the future excess return of stocks). The scores/predictions of the models will be used as the mined alpha. Investing in stocks with higher scores is expected to yield more profit.
|
||||
|
||||
The alpha is evaluated in two ways.
|
||||
1. The correlation between the alpha and future return.
|
||||
1. Constructing a portfolio based on the alpha and evaluating the final total return.
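
The first evaluation (correlation between the alpha and future return) can be sketched as a per-day cross-sectional correlation. This is an illustration under assumptions: the column names `score` and `label`, the helper `daily_ic`, and the data are all made up, and ICIR is taken here as the mean of the daily IC divided by its standard deviation.

```python
import pandas as pd


def daily_ic(df: pd.DataFrame, method: str = "pearson") -> pd.Series:
    """Per-day correlation between `score` and `label` (hypothetical helper)."""
    return df.groupby(level="datetime").apply(
        lambda g: g["score"].corr(g["label"], method=method)
    )


# Two made-up days with four instruments each.
index = pd.MultiIndex.from_product(
    [pd.to_datetime(["2019-01-04", "2019-01-07"]), ["A", "B", "C", "D"]],
    names=["datetime", "instrument"],
)
df = pd.DataFrame(
    {
        "score": [0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1],
        "label": [0.01, 0.02, 0.03, 0.04, 0.01, 0.02, 0.03, 0.05],
    },
    index=index,
)

ic = daily_ic(df, method="pearson")        # IC: Pearson correlation per day
rank_ic = daily_ic(df, method="spearman")  # Rank IC: Spearman correlation per day
icir = ic.mean() / ic.std()                # assumed definition of ICIR
print(ic, rank_ic, icir, sep="\n")
```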
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` datasets with China's A-share stock & CSI300 data. The values of each metric are the mean and std calculated based on 20 runs with different random seeds.
|
||||
|
||||
@@ -21,7 +16,6 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
|
||||
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
|
||||
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
|
||||
@@ -41,6 +35,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
|
||||
|
||||
|
||||
|
||||
## Alpha360 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -53,12 +48,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost(Tianqi Chen, et al.) | Alpha360 | 0.0394±0.00 | 0.2909±0.00 | 0.0448±0.00 | 0.3679±0.00 | 0.0344±0.00 | 0.4527±0.02 | -0.1004±0.00 |
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha360 | 0.0404±0.00 | 0.3023±0.00 | 0.0495±0.00 | 0.3898±0.00 | 0.0468±0.01 | 0.6302±0.20 | -0.0860±0.01 |
|
||||
| LightGBM(Guolin Ke, et al.) | Alpha360 | 0.0400±0.00 | 0.3037±0.00 | 0.0499±0.00 | 0.4042±0.00 | 0.0558±0.00 | 0.7632±0.00 | -0.0659±0.00 |
|
||||
| TCN(Shaojie Bai, et al.) | Alpha360 | 0.0441±0.00 | 0.3301±0.02 | 0.0519±0.00 | 0.4130±0.01 | 0.0604±0.02 | 0.8295±0.34 | -0.1018±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0497±0.00 | 0.3829±0.04 | 0.0599±0.00 | 0.4736±0.03 | 0.0626±0.02 | 0.8651±0.31 | -0.0994±0.03 |
|
||||
| LSTM(Sepp Hochreiter, et al.) | Alpha360 | 0.0448±0.00 | 0.3474±0.04 | 0.0549±0.00 | 0.4366±0.03 | 0.0647±0.03 | 0.8963±0.39 | -0.0875±0.02 |
|
||||
| ADD | Alpha360 | 0.0430±0.00 | 0.3188±0.04 | 0.0559±0.00 | 0.4301±0.03 | 0.0667±0.02 | 0.8992±0.34 | -0.0855±0.02 |
|
||||
| GRU(Kyunghyun Cho, et al.) | Alpha360 | 0.0493±0.00 | 0.3772±0.04 | 0.0584±0.00 | 0.4638±0.03 | 0.0720±0.02 | 0.9730±0.33 | -0.0821±0.02 |
|
||||
| AdaRNN(Yuntao Du, et al.) | Alpha360 | 0.0464±0.01 | 0.3619±0.08 | 0.0539±0.01 | 0.4287±0.06 | 0.0753±0.03 | 1.0200±0.40 | -0.0936±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0476±0.00 | 0.3508±0.02 | 0.0598±0.00 | 0.4604±0.01 | 0.0824±0.02 | 1.1079±0.26 | -0.0894±0.03 |
|
||||
| TCTS(Xueqing Wu, et al.) | Alpha360 | 0.0508±0.00 | 0.3931±0.04 | 0.0599±0.00 | 0.4756±0.03 | 0.0893±0.03 | 1.2256±0.36 | -0.0857±0.02 |
|
||||
| TRA(Hengxu Lin, et al.) | Alpha360 | 0.0485±0.00 | 0.3787±0.03 | 0.0587±0.00 | 0.4756±0.03 | 0.0920±0.03 | 1.2789±0.42 | -0.0834±0.02 |
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# TCN
|
||||
* Code: [https://github.com/locuslab/TCN](https://github.com/locuslab/TCN)
|
||||
* Paper: [An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling](https://arxiv.org/abs/1803.01271).
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
numpy==1.17.4
|
||||
pandas==1.1.2
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,100 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: FilterCol
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
|
||||
"ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
|
||||
"RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
|
||||
]
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCN
|
||||
module_path: qlib.contrib.model.pytorch_tcn_ts
|
||||
kwargs:
|
||||
d_feat: 20
|
||||
num_layers: 5
|
||||
n_chans: 32
|
||||
kernel_size: 7
|
||||
dropout: 0.5
|
||||
n_epochs: 200
|
||||
lr: 1e-4
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
n_jobs: 20
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: TSDatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
step_len: 20
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -1,90 +0,0 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCN
|
||||
module_path: qlib.contrib.model.pytorch_tcn
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
num_layers: 5
|
||||
n_chans: 128
|
||||
kernel_size: 3
|
||||
dropout: 0.5
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 2000
|
||||
metric: loss
|
||||
loss: mse
|
||||
optimizer: adam
|
||||
GPU: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -95,4 +95,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# TabNet
|
||||
* Code: [https://github.com/dreamquark-ai/tabnet](https://github.com/dreamquark-ai/tabnet)
|
||||
* Paper: [TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/pdf/1908.07442.pdf).
|
||||
@@ -1,3 +0,0 @@
|
||||
# Transformer
|
||||
* Code: [https://github.com/tensorflow/tensor2tensor](https://github.com/tensorflow/tensor2tensor)
|
||||
* Paper: [Attention is All you Need](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf).
|
||||
@@ -1,2 +0,0 @@
|
||||
# Introduction
|
||||
The examples in this folder try to demonstrate some common usage of data-related modules of Qlib
|
||||
@@ -1,53 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show that the data modules of Qlib are serializable: users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
import subprocess
|
||||
import yaml
|
||||
from qlib.log import TimeInspector
|
||||
|
||||
from qlib import init
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
# For general purpose, we use relative path
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
if __name__ == "__main__":
|
||||
init()
|
||||
|
||||
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
|
||||
|
||||
# 1) show original time
|
||||
with TimeInspector.logt("The original time without handler cache:"):
|
||||
subprocess.run(f"qrun {config_path}", shell=True)
|
||||
|
||||
# 2) dump handler
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
|
||||
pprint(hd_conf)
|
||||
hd: DataHandlerLP = init_instance_by_config(hd_conf)
|
||||
hd_path = DIRNAME / "handler.pkl"
|
||||
hd.to_pickle(hd_path, dump_all=True)
|
||||
|
||||
# 3) create new task with handler cache
|
||||
new_task_config = deepcopy(task_config)
|
||||
new_task_config["task"]["dataset"]["kwargs"]["handler"] = f"file://{hd_path}"
|
||||
new_task_config["sys"] = {"path": [str(config_path.parent.resolve())]}
|
||||
new_task_path = DIRNAME / "new_task.yaml"
|
||||
print("The location of the new task", new_task_path)
|
||||
|
||||
# save new task
|
||||
with new_task_path.open("w") as f:
|
||||
yaml.safe_dump(new_task_config, f, indent=4, sort_keys=False)
|
||||
|
||||
# 4) train model with new task
|
||||
with TimeInspector.logt("The time for task with handler cache:"):
|
||||
subprocess.run(f"qrun {new_task_path}", shell=True)
|
||||
@@ -1,59 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The motivation of this demo
|
||||
- To show that the data modules of Qlib are serializable: users can dump processed data to disk to avoid duplicated data preprocessing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from pprint import pprint
|
||||
import subprocess
|
||||
|
||||
import yaml
|
||||
|
||||
from qlib import init
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.log import TimeInspector
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
# For general purpose, we use relative path
|
||||
DIRNAME = Path(__file__).absolute().resolve().parent
|
||||
|
||||
if __name__ == "__main__":
|
||||
init()
|
||||
|
||||
repeat = 2
|
||||
exp_name = "data_mem_reuse_demo"
|
||||
|
||||
config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml"
|
||||
task_config = yaml.safe_load(config_path.open())
|
||||
|
||||
# 1) without using processed data in memory
|
||||
with TimeInspector.logt("The original time without reusing processed data in memory:"):
|
||||
for i in range(repeat):
|
||||
task_train(task_config["task"], experiment_name=exp_name)
|
||||
|
||||
# 2) prepare processed data in memory.
|
||||
hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"]
|
||||
pprint(hd_conf)
|
||||
hd: DataHandlerLP = init_instance_by_config(hd_conf)
|
||||
|
||||
# 3) with reusing processed data in memory
|
||||
new_task = deepcopy(task_config["task"])
|
||||
new_task["dataset"]["kwargs"]["handler"] = hd
|
||||
print(new_task)
|
||||
|
||||
with TimeInspector.logt("The time with reusing processed data in memory:"):
|
||||
# this will save the time to reload and process data from disk(in `DataHandlerLP`)
|
||||
# It still takes a lot of time in the backtest phase
|
||||
for i in range(repeat):
|
||||
task_train(new_task, experiment_name=exp_name)
|
||||
|
||||
# 4) Users can change other parts of the task, excluding the processed data in memory (the handler)
|
||||
new_task = deepcopy(task_config["task"])
|
||||
new_task["dataset"]["kwargs"]["segments"]["train"] = ("20100101", "20131231")
|
||||
with TimeInspector.logt("The time with reusing processed data in memory:"):
|
||||
task_train(new_task, experiment_name=exp_name)
|
||||
@@ -1,20 +1,15 @@
|
||||
# Introduction
|
||||
This folder contains 2 examples
|
||||
- A high-frequency dataset example
|
||||
- An example of predicting the price trend in high-frequency data
|
||||
|
||||
## High-Frequency Dataset
|
||||
# High-Frequency Dataset
|
||||
|
||||
This dataset is an example for RL high frequency trading.
|
||||
|
||||
### Get High-Frequency Data
|
||||
## Get High-Frequency Data
|
||||
|
||||
Get high-frequency data by running the following command:
|
||||
```bash
|
||||
python workflow.py get_data
|
||||
```
|
||||
|
||||
### Dump & Reload & Reinitialize the Dataset
|
||||
## Dump & Reload & Reinitialize the Dataset
|
||||
|
||||
|
||||
The High-Frequency Dataset is implemented as `qlib.data.dataset.DatasetH` in `workflow.py`. `DatasetH` is a subclass of [`qlib.utils.serial.Serializable`](https://qlib.readthedocs.io/en/latest/advanced/serial.html), whose state can be dumped to or loaded from disk in `pickle` format.
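
As a rough illustration of the dump-and-reload idea described above, here is a minimal sketch that uses only the standard `pickle` module. `TinyDataset` and `load_dataset` are hypothetical stand-ins, not Qlib classes or functions, and skipping underscore-prefixed attributes on dump is an assumption made for this sketch; the linked `Serializable` documentation describes the library's actual behaviour.

```python
import pickle
import tempfile
from pathlib import Path


class TinyDataset:
    """Hypothetical stand-in for a serializable dataset (not Qlib's DatasetH)."""

    def __init__(self, segments):
        self.segments = segments
        # Pretend this is large processed data that we do not want to dump.
        self._cache = list(range(5))

    def __getstate__(self):
        # Assumption for this sketch: skip attributes starting with "_".
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def to_pickle(self, path):
        with Path(path).open("wb") as f:
            pickle.dump(self, f)


def load_dataset(path):
    """Reload a dumped object from disk."""
    with Path(path).open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "dataset.pkl"
    TinyDataset({"train": ("2020-01-01", "2020-06-30")}).to_pickle(path)
    restored = load_dataset(path)

print(restored.segments)            # the configuration survives the round trip
print(hasattr(restored, "_cache"))  # the skipped attribute must be rebuilt after loading
```

In this sketch, anything left out of the dump has to be recomputed after loading, which is roughly what a "reinitialize" step is for.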
|
||||
@@ -32,10 +27,9 @@ Run the example by running the following command:
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
|
||||
## Benchmarks Performance (predicting the price trend in high-frequency data)
|
||||
|
||||
Here are the results of models for predicting the price trend in high-frequency data. We will keep updating the benchmark models in the future.
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of the signal test for benchmark models. We will keep updating the benchmark models in the future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long Precision | Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.0349±0.00 | 0.3805±0.00| 0.0435±0.00 | 0.4724±0.00 | 0.5111±0.00 | 0.5428±0.00 | 0.000074±0.00 | 0.2677±0.00 |
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
|
||||
@@ -59,7 +59,7 @@ task:
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
kwargs:
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
@@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerR, TrainerRM, task_train
|
||||
from qlib.model.trainer import TrainerRM, task_train
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class RollingTaskExample:
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
experiment_name="rolling_exp",
|
||||
task_pool=None,  # if None, TrainerR is used; pass a pool name such as "rolling_task" to use TrainerRM
|
||||
task_pool="rolling_task",
|
||||
task_config=None,
|
||||
rolling_step=550,
|
||||
rolling_type=RollingGen.ROLL_SD,
|
||||
@@ -43,19 +43,14 @@ class RollingTaskExample:
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.experiment_name = experiment_name
|
||||
if task_pool is None:
|
||||
self.trainer = TrainerR(experiment_name=self.experiment_name)
|
||||
else:
|
||||
self.task_pool = task_pool
|
||||
self.trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
self.task_pool = task_pool
|
||||
self.task_config = task_config
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
|
||||
|
||||
# Reset everything to the initial state; be careful to save important data first
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
TaskManager(task_pool=self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.experiment_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
@@ -71,10 +66,10 @@ class RollingTaskExample:
|
||||
|
||||
def task_training(self, tasks):
|
||||
print("========== task_training ==========")
|
||||
self.trainer.train(tasks)
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# NOTE: this is only used for TrainerRM
|
||||
# Train tasks in other processes or on other machines for multiprocessing. This is the same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
@@ -1,105 +1,9 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
The expected result of `backtest` in the current version is as follows
|
||||
|
||||
'The following are analysis results of benchmark return(1day).'
|
||||
risk
|
||||
mean 0.000651
|
||||
std 0.012472
|
||||
annualized_return 0.154967
|
||||
information_ratio 0.805422
|
||||
max_drawdown -0.160445
|
||||
'The following are analysis results of the excess return without cost(1day).'
|
||||
risk
|
||||
mean 0.001258
|
||||
std 0.007575
|
||||
annualized_return 0.299303
|
||||
information_ratio 2.561219
|
||||
max_drawdown -0.068386
|
||||
'The following are analysis results of the excess return with cost(1day).'
|
||||
risk
|
||||
mean 0.001110
|
||||
std 0.007575
|
||||
annualized_return 0.264280
|
||||
information_ratio 2.261392
|
||||
max_drawdown -0.071842
|
||||
[1706497:MainThread](2021-12-07 14:08:30,263) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_30minute.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of benchmark return(30minute).'
|
||||
risk
|
||||
mean 0.000078
|
||||
std 0.003646
|
||||
annualized_return 0.148787
|
||||
information_ratio 0.935252
|
||||
max_drawdown -0.142830
|
||||
('The following are analysis results of the excess return without '
|
||||
'cost(30minute).')
|
||||
risk
|
||||
mean 0.000174
|
||||
std 0.003343
|
||||
annualized_return 0.331867
|
||||
information_ratio 2.275019
|
||||
max_drawdown -0.074752
|
||||
'The following are analysis results of the excess return with cost(30minute).'
|
||||
risk
|
||||
mean 0.000155
|
||||
std 0.003343
|
||||
annualized_return 0.294536
|
||||
information_ratio 2.018860
|
||||
max_drawdown -0.075579
|
||||
[1706497:MainThread](2021-12-07 14:08:30,277) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_5minute.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of benchmark return(5minute).'
|
||||
risk
|
||||
mean 0.000015
|
||||
std 0.001460
|
||||
annualized_return 0.172170
|
||||
information_ratio 1.103439
|
||||
max_drawdown -0.144807
|
||||
'The following are analysis results of the excess return without cost(5minute).'
|
||||
risk
|
||||
mean 0.000028
|
||||
std 0.001412
|
||||
annualized_return 0.319771
|
||||
information_ratio 2.119563
|
||||
max_drawdown -0.077426
|
||||
'The following are analysis results of the excess return with cost(5minute).'
|
||||
risk
|
||||
mean 0.000025
|
||||
std 0.001412
|
||||
annualized_return 0.281536
|
||||
information_ratio 1.866091
|
||||
max_drawdown -0.078194
|
||||
[1706497:MainThread](2021-12-07 14:08:30,287) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of indicators(1day).'
|
||||
value
|
||||
ffr 0.945821
|
||||
pa 0.000324
|
||||
pos 0.542882
|
||||
[1706497:MainThread](2021-12-07 14:08:30,293) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_30minute.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of indicators(30minute).'
|
||||
value
|
||||
ffr 0.982910
|
||||
pa 0.000037
|
||||
pos 0.500806
|
||||
[1706497:MainThread](2021-12-07 14:08:30,302) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_5minute.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of indicators(5minute).'
|
||||
value
|
||||
ffr 0.991017
|
||||
pa 0.000000
|
||||
pos 0.000000
|
||||
[1706497:MainThread](2021-12-07 14:08:30,627) INFO - qlib.timer - [log.py:113] - Time cost: 0.014s | waiting `async_log` Done
|
||||
"""
|
||||
|
||||
|
||||
from copy import deepcopy
|
||||
import qlib
|
||||
import fire
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
|
||||
from qlib.data import D
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
@@ -110,6 +14,7 @@ from qlib.backtest import collect_data
|
||||
|
||||
|
||||
class NestedDecisionExecutionWorkflow:
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
data_handler_config = {
|
||||
@@ -257,10 +162,13 @@ class NestedDecisionExecutionWorkflow:
|
||||
self.port_analysis_config["backtest"]["benchmark"] = self.benchmark
|
||||
|
||||
with R.start(experiment_name="backtest"):
|
||||
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(
|
||||
recorder,
|
||||
self.port_analysis_config,
|
||||
risk_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_method="value_weighted",
|
||||
)
|
||||
par.generate()
|
||||
@@ -291,101 +199,6 @@ class NestedDecisionExecutionWorkflow:
|
||||
for trade_decision in data_generator:
|
||||
print(trade_decision)
|
||||
|
||||
# The code below is for checking; users don't need to care about it
|
||||
# The tests can be categorized into 2 types
|
||||
# 1) comparing same backtest
|
||||
# - Basic test idea: the shared accumulated values are equal across multiple levels
|
||||
# - Aligning the profit calculation between multiple levels and single levels.
|
||||
# 2) comparing different backtest
|
||||
# - Basic test idea:
|
||||
# - the daily backtest will be similar to the multi-level one (the data quality makes this gap smaller)
|
||||
|
||||
def check_diff_freq(self):
|
||||
self._init_qlib()
|
||||
exp = R.get_exp(experiment_name="backtest")
|
||||
rec = next(iter(exp.list_recorders().values())) # assuming this will get the latest recorder
|
||||
for check_key in "account", "total_turnover", "total_cost":
|
||||
check_key = "total_cost"
|
||||
|
||||
acc_dict = {}
|
||||
for freq in ["30minute", "5minute", "1day"]:
|
||||
acc_dict[freq] = rec.load_object(f"portfolio_analysis/report_normal_{freq}.pkl")[check_key]
|
||||
acc_df = pd.DataFrame(acc_dict)
|
||||
acc_resam = acc_df.resample("1d").last().dropna()
|
||||
assert (acc_resam["30minute"] == acc_resam["1day"]).all()
|
||||
|
||||
def backtest_only_daily(self):
|
||||
"""
|
||||
This backtest is used for comparing the nested execution and single layer execution
|
||||
Due to the low quality of the daily-level and minute-level data, they are hardly comparable.
|
||||
So it is used only for detecting serious bugs that make the results differ greatly.
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
[1724971:MainThread](2021-12-07 16:24:31,156) INFO - qlib.workflow - [record_temp.py:441] - Portfolio analysis record 'port_analysis_1day.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of benchmark return(1day).'
|
||||
risk
|
||||
mean 0.000651
|
||||
std 0.012472
|
||||
annualized_return 0.154967
|
||||
information_ratio 0.805422
|
||||
max_drawdown -0.160445
|
||||
'The following are analysis results of the excess return without cost(1day).'
|
||||
risk
|
||||
mean 0.001375
|
||||
std 0.006103
|
||||
annualized_return 0.327204
|
||||
information_ratio 3.475016
|
||||
max_drawdown -0.024927
|
||||
'The following are analysis results of the excess return with cost(1day).'
|
||||
risk
|
||||
mean 0.001184
|
||||
std 0.006091
|
||||
annualized_return 0.281801
|
||||
information_ratio 2.998749
|
||||
max_drawdown -0.029568
|
||||
[1724971:MainThread](2021-12-07 16:24:31,170) INFO - qlib.workflow - [record_temp.py:466] - Indicator analysis record 'indicator_analysis_1day.pkl' has been saved as the artifact of the Experiment 2
|
||||
'The following are analysis results of indicators(1day).'
|
||||
value
|
||||
ffr 1.0
|
||||
pa 0.0
|
||||
pos 0.0
|
||||
[1724971:MainThread](2021-12-07 16:24:31,188) INFO - qlib.timer - [log.py:113] - Time cost: 0.007s | waiting `async_log` Done
|
||||
|
||||
"""
|
||||
self._init_qlib()
|
||||
model = init_instance_by_config(self.task["model"])
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
self._train_model(model, dataset)
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}
|
||||
pa_conf = deepcopy(self.port_analysis_config)
|
||||
pa_conf["strategy"] = strategy_config
|
||||
pa_conf["executor"] = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"generate_portfolio_metrics": True,
|
||||
"verbose": True,
|
||||
},
|
||||
}
|
||||
pa_conf["backtest"]["benchmark"] = self.benchmark
|
||||
|
||||
with R.start(experiment_name="backtest"):
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(recorder, pa_conf)
|
||||
par.generate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(NestedDecisionExecutionWorkflow)
|
||||
|
||||
@@ -248,7 +248,7 @@ class ModelRunner:
|
||||
determines the dataset to be used for each model.
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be a URL on the web or a local path (NOTE: the local path must be an absolute path)
|
||||
it could be a URL on the web or a local path
|
||||
exp_folder_name: str
|
||||
the name of the experiment folder
|
||||
wait_before_rm_env : bool
|
||||
|
||||
@@ -62,6 +62,12 @@
|
||||
"import qlib\n",
|
||||
"import pandas as pd\n",
|
||||
"from qlib.config import REG_CN\n",
|
||||
"from qlib.contrib.model.gbdt import LGBModel\n",
|
||||
"from qlib.contrib.data.handler import Alpha158\n",
|
||||
"from qlib.contrib.evaluate import (\n",
|
||||
" backtest as normal_backtest,\n",
|
||||
" risk_analysis,\n",
|
||||
")\n",
|
||||
"from qlib.utils import exists_qlib_data, init_instance_by_config\n",
|
||||
"from qlib.workflow import R\n",
|
||||
"from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
|
||||
@@ -198,7 +204,7 @@
|
||||
" },\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"class\": \"TopkDropoutStrategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"model\": model,\n",
|
||||
" \"dataset\": dataset,\n",
|
||||
|
||||
@@ -5,7 +5,7 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord, SigAnaRecord
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
@@ -70,10 +70,6 @@ if __name__ == "__main__":
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# Signal Analysis
|
||||
sar = SigAnaRecord(recorder)
|
||||
sar.generate()
|
||||
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
par = PortAnaRecord(recorder, port_analysis_config, "day")
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.0.99"
|
||||
_version_path = Path(__file__).absolute().parent / "VERSION.txt" # This file is copied from setup.py
|
||||
__version__ = _version_path.read_text(encoding="utf-8").strip()
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -15,16 +16,6 @@ from .log import get_module_logger
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
**kwargs :
|
||||
clear_mem_cache: bool
|
||||
the default value is True;
|
||||
Whether the memory cache will be cleared.
|
||||
It is often used to improve performance when `init` is called multiple times
|
||||
"""
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
|
||||
@@ -38,9 +29,7 @@ def init(default_conf="client", **kwargs):
|
||||
logger.warning("Skip initialization because `skip_if_reg is True`")
|
||||
return
|
||||
|
||||
clear_mem_cache = kwargs.pop("clear_mem_cache", True)
|
||||
if clear_mem_cache:
|
||||
H.clear()
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# mount nfs
|
||||
|
||||
@@ -50,12 +50,11 @@ def get_exchange(
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. It is a ratio. The cost is proportional to your order's deal amount.
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost. It is a ratio. The cost is proportional to your order's deal amount.
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
@@ -186,10 +185,8 @@ def get_strategy_executor(
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
trade_strategy.reset_common_infra(common_infra)
|
||||
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor)
|
||||
trade_executor.reset_common_infra(common_infra)
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
|
||||
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
|
||||
|
||||
return trade_strategy, trade_executor
|
||||
|
||||
|
||||
@@ -29,10 +29,7 @@ rtn & earning in the Account
|
||||
|
||||
|
||||
class AccumulatedInfo:
|
||||
"""
|
||||
accumulated trading info, including accumulated return/cost/turnover
|
||||
AccumulatedInfo should be shared across different levels
|
||||
"""
|
||||
"""accumulated trading info, including accumulated return/cost/turnover"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -65,11 +62,6 @@ class AccumulatedInfo:
|
||||
|
||||
|
||||
class Account:
|
||||
"""
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor
|
||||
Each level of executor has its own Account object when calculating metrics, but the position object is shared across all the Account objects.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
init_cash: float = 1e9,
|
||||
@@ -103,8 +95,6 @@ class Account:
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
# 1) the following variables are shared by multiple layers
|
||||
# - you will see a shallow copy instead of deepcopy in the NestedExecutor;
|
||||
self.init_cash = init_cash
|
||||
self.current_position: BasePosition = init_instance_by_config(
|
||||
{
|
||||
@@ -116,9 +106,6 @@ class Account:
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
|
||||
# 2) following variables are not shared between layers
|
||||
self.portfolio_metrics = None
|
||||
self.hist_positions = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
@@ -132,8 +119,7 @@ class Account:
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
# NOTE:
|
||||
# `accum_info` and `current_position` are shared here
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.portfolio_metrics = PortfolioMetrics(freq, benchmark_config)
|
||||
self.hist_positions = {}
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ class Exchange:
|
||||
self.extra_quote["limit_buy"] = False
|
||||
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
|
||||
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
|
||||
self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0)
|
||||
self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0)
|
||||
|
||||
LT_TP_EXP = "(exp)" # Tuple[str, str]
|
||||
LT_FLT = "float" # float
|
||||
@@ -401,9 +401,9 @@ class Exchange:
|
||||
def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time, method="sum"):
|
||||
def get_volume(self, stock_id, start_time, end_time):
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method="sum")
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
|
||||
if direction == OrderDir.SELL:
|
||||
@@ -736,11 +736,7 @@ class Exchange:
|
||||
|
||||
# TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
|
||||
trade_val = order.deal_amount * trade_price
|
||||
if not total_trade_val or np.isnan(total_trade_val):
|
||||
# TODO: assert trade_val == 0, f"trade_val != 0, total_trade_val: {total_trade_val}; order info: {order}"
|
||||
adj_cost_ratio = self.impact_cost
|
||||
else:
|
||||
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
|
||||
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
|
||||
|
||||
if order.direction == Order.SELL:
|
||||
cost_ratio = self.close_cost + adj_cost_ratio
|
||||
|
||||
@@ -130,7 +130,7 @@ class BaseExecutor:
|
||||
|
||||
if common_infra.has("trade_account"):
|
||||
# NOTE: there is a trick in the code.
|
||||
# shallow copy is used instead of deepcopy. So positions are shared
|
||||
# copy is used instead of deepcopy. So positions are shared
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
@@ -395,25 +395,9 @@ class NestedExecutor(BaseExecutor):
|
||||
if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx:
|
||||
# if force align the range limit, skip the steps outside the decision range limit
|
||||
|
||||
res = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
|
||||
# NOTE: !!!!!
|
||||
# the two lines below are for a special case in RL
|
||||
# To resolve the conflict below
|
||||
# - Normally, users create a strategy and embed it into Qlib's executor and simulator interaction loop
|
||||
# For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
|
||||
# - However, an RL-based framework has its own script to run the loop
|
||||
# For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
|
||||
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution below is proposed
|
||||
# - The entry script follows the _RL learning example_ to be compatible with all kinds of RL frameworks
|
||||
# - Each step of (RL Env) will make (inner Qlib Executor) one step forward
|
||||
# - (inner Qlib Strategy) is a proxy strategy; it hands control of the program to (RL Env) via `yield from` and waits for the action from the policy
|
||||
# So the two lines below are the implementation of yielding control
|
||||
if isinstance(res, GeneratorType):
|
||||
res = yield from res
|
||||
|
||||
_inner_trade_decision: BaseTradeDecision = res
|
||||
|
||||
_inner_trade_decision: BaseTradeDecision = self.inner_strategy.generate_trade_decision(
|
||||
_inner_execute_result
|
||||
)
|
||||
trade_decision.mod_inner_decision(_inner_trade_decision) # propagate part of decision information
|
||||
|
||||
# NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting
|
||||
@@ -423,7 +407,6 @@ class NestedExecutor(BaseExecutor):
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(
|
||||
trade_decision=_inner_trade_decision, level=level + 1
|
||||
)
|
||||
self.post_inner_exe_step(_inner_execute_result)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
@@ -435,17 +418,6 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def post_inner_exe_step(self, inner_exe_res):
|
||||
"""
|
||||
A hook for doing something after each step of the inner strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
@@ -223,12 +223,6 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||
|
||||
def __str__(self):
|
||||
return self.__dict__.__str__()
|
||||
|
||||
def __repr__(self):
|
||||
return self.__dict__.__repr__()
|
||||
|
||||
|
||||
class Position(BasePosition):
|
||||
"""Position
|
||||
|
||||
@@ -55,9 +55,9 @@ class TradeCalendarManager:
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
|
||||
_calendar = Cal.calendar(freq=freq, future=True)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True)
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq)
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
@@ -70,7 +70,7 @@ class TradeCalendarManager:
|
||||
- If self.trade_step >= self.trade_len, it means the trading is finished
|
||||
- If self.trade_step < self.trade_len, it means the number of trading steps finished is self.trade_step
|
||||
"""
|
||||
return self.trade_step >= self.trade_len
|
||||
return self.trade_step >= self.trade_len - 1
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
@@ -222,7 +222,7 @@ class CommonInfrastructure(BaseInfrastructure):
|
||||
|
||||
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
"""level infrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
"""level instrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
|
||||
def get_support_infra(self):
|
||||
"""
|
||||
|
||||
@@ -10,7 +10,6 @@ Two modes are supported
|
||||
- server
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
@@ -19,11 +18,7 @@ import logging
|
||||
import platform
|
||||
import multiprocessing
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.utils.time import Freq
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -78,9 +73,6 @@ class Config:
|
||||
REG_CN = "cn"
|
||||
REG_US = "us"
|
||||
|
||||
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||
PROTOCOL_VERSION = 4
|
||||
|
||||
NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)
|
||||
|
||||
DISK_DATASET_CACHE = "DiskDatasetCache"
|
||||
@@ -115,8 +107,6 @@ _default_config = {
|
||||
# for simple dataset cache
|
||||
"local_cache_path": None,
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# pickle.dump protocol version
|
||||
"dump_protocol_version": PROTOCOL_VERSION,
|
||||
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
|
||||
"maxtasksperchild": None,
|
||||
# If joblib_backend is None, use loky
|
||||
@@ -249,8 +239,8 @@ HIGH_FREQ_CONFIG = {
|
||||
_default_region_config = {
|
||||
REG_CN: {
|
||||
"trade_unit": 100,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"limit_threshold": 0.099,
|
||||
"deal_price": "vwap",
|
||||
},
|
||||
REG_US: {
|
||||
"trade_unit": 1,
|
||||
@@ -275,20 +265,6 @@ class QlibConfig(Config):
|
||||
self.provider_uri = provider_uri
|
||||
self.mount_path = mount_path
|
||||
|
||||
@staticmethod
|
||||
def format_provider_uri(provider_uri: Union[str, dict, Path]) -> dict:
|
||||
if provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if isinstance(provider_uri, (str, dict, Path)):
|
||||
if not isinstance(provider_uri, dict):
|
||||
provider_uri = {QlibConfig.DEFAULT_FREQ: provider_uri}
|
||||
else:
|
||||
raise TypeError(f"provider_uri does not support {type(provider_uri)}")
|
||||
for freq, _uri in provider_uri.items():
|
||||
if QlibConfig.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
provider_uri[freq] = str(Path(_uri).expanduser().resolve())
|
||||
return provider_uri
|
||||
|
||||
@staticmethod
|
||||
def get_uri_type(uri: Union[str, Path]):
|
||||
uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve())
|
||||
@@ -301,9 +277,7 @@ class QlibConfig(Config):
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def get_data_uri(self, freq: Optional[Union[str, Freq]] = None) -> Path:
|
||||
if freq is not None:
|
||||
freq = str(freq) # converting Freq to string
|
||||
def get_data_uri(self, freq: str = None) -> Path:
|
||||
if freq is None or freq not in self.provider_uri:
|
||||
freq = QlibConfig.DEFAULT_FREQ
|
||||
_provider_uri = self.provider_uri[freq]
|
||||
@@ -337,7 +311,11 @@ class QlibConfig(Config):
|
||||
def resolve_path(self):
|
||||
# resolve path
|
||||
_mount_path = self["mount_path"]
|
||||
_provider_uri = self.DataPathManager.format_provider_uri(self["provider_uri"])
|
||||
_provider_uri = self["provider_uri"]
|
||||
if _provider_uri is None:
|
||||
raise ValueError("provider_uri cannot be None")
|
||||
if not isinstance(_provider_uri, dict):
|
||||
_provider_uri = {self.DEFAULT_FREQ: _provider_uri}
|
||||
if not isinstance(_mount_path, dict):
|
||||
_mount_path = {_freq: _mount_path for _freq in _provider_uri.keys()}
|
||||
|
||||
@@ -346,7 +324,10 @@ class QlibConfig(Config):
|
||||
assert len(_miss_freq) == 0, f"mount_path is missing freq: {_miss_freq}"
|
||||
|
||||
# resolve
|
||||
for _freq in _provider_uri.keys():
|
||||
for _freq, _uri in _provider_uri.items():
|
||||
# provider_uri
|
||||
if self.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI:
|
||||
_provider_uri[_freq] = str(Path(_uri).expanduser().resolve())
|
||||
# mount_path
|
||||
_mount_path[_freq] = (
|
||||
_mount_path[_freq]
|
||||
@@ -356,6 +337,20 @@ class QlibConfig(Config):
|
||||
self["provider_uri"] = _provider_uri
|
||||
self["mount_path"] = _mount_path
|
||||
|
||||
def get_uri_type(self):
|
||||
path = self["provider_uri"]
|
||||
if isinstance(path, Path):
|
||||
path = str(path)
|
||||
is_win = re.match("^[a-zA-Z]:.*", path) is not None # such as 'C:\\data', 'D:'
|
||||
is_nfs_or_win = (
|
||||
re.match("^[^/]+:.+", path) is not None
|
||||
) # such as 'host:/data/' (User may define short hostname by themselves or use localhost)
|
||||
|
||||
if is_nfs_or_win and not is_win:
|
||||
return QlibConfig.NFS_URI
|
||||
else:
|
||||
return QlibConfig.LOCAL_URI
|
||||
|
||||
def set(self, default_conf: str = "client", **kwargs):
|
||||
"""
|
||||
configure qlib based on the input parameters
|
||||
|
||||
@@ -90,13 +90,7 @@ class Alpha360(DataHandlerLP):
|
||||
return (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
|
||||
|
||||
def get_feature_config(self):
|
||||
# NOTE:
|
||||
# Alpha360 tries to provide a dataset with original price data
|
||||
# the original price data includes the prices and volume in the last 60 days.
|
||||
# To make it easier to learn models from this dataset, all the prices and volume
|
||||
# are normalized by the latest price and volume data (dividing by $close, $volume)
|
||||
# So the latest normalized $close will be 1 (with name CLOSE0), the latest normalized $volume will be 1 (with name VOLUME0)
|
||||
# If further normalization is applied (e.g. centralization), CLOSE0 and VOLUME0 will be 0.
|
||||
|
||||
fields = []
|
||||
names = []
|
||||
|
||||
|
||||
@@ -3,18 +3,15 @@
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from logging import warn
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Union
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..backtest import get_exchange, backtest as backtest_func
|
||||
from ..utils import get_date_range
|
||||
from ..utils.resam import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..backtest import get_exchange, position, backtest as backtest_func, executor as _executor
|
||||
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
@@ -120,129 +117,84 @@ def indicator_analysis(df, method="mean"):
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest_daily(
|
||||
start_time: Union[str, pd.Timestamp],
|
||||
end_time: Union[str, pd.Timestamp],
|
||||
strategy: Union[str, dict, BaseStrategy],
|
||||
executor: Union[str, dict, _executor.BaseExecutor] = None,
|
||||
account: Union[float, int, position.Position] = 1e8,
|
||||
benchmark: str = "SH000300",
|
||||
exchange_kwargs: dict = None,
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
"""initialize the strategy and executor, then execute the backtest at daily frequency
|
||||
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
Parameters
|
||||
----------
|
||||
start_time : Union[str, pd.Timestamp]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outermost executor's calendar.
|
||||
end_time : Union[str, pd.Timestamp]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outermost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
strategy : Union[str, dict, BaseStrategy]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
|
||||
|
||||
E.g.
|
||||
- **backtest workflow related or common arguments**
|
||||
|
||||
.. code-block:: python
|
||||
# dict
|
||||
strategy = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.signal_strategy",
|
||||
"kwargs": {
|
||||
"signal": (model, dataset),
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}
|
||||
# BaseStrategy
|
||||
pred_score = pd.read_pickle("score.pkl")["score"]
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
"signal": pred_score,
|
||||
}
|
||||
strategy = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
# str example.
|
||||
# 1) specify a pickle object
|
||||
# - path like 'file:///<path to pickle file>/obj.pkl'
|
||||
# 2) specify a class name
|
||||
# - "ClassName": getattr(module, "ClassName")() will be used.
|
||||
# 3) specify module path with class name
|
||||
# - "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
|
||||
pred : pandas.DataFrame
|
||||
the prediction should have a <datetime, instrument> index and one `score` column.
|
||||
account : float
|
||||
init account value.
|
||||
shift : int
|
||||
whether to shift prediction by one day.
|
||||
benchmark : str
|
||||
benchmark code, default is SH000905 CSI 500.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
- **strategy related arguments**
|
||||
|
||||
executor : Union[str, dict, BaseExecutor]
|
||||
for initializing the outermost executor.
|
||||
benchmark: str
|
||||
the benchmark for reporting.
|
||||
account : Union[float, int, Position]
|
||||
information describing how to create the account
|
||||
For `float` or `int`:
|
||||
Using Account with only initial cash
|
||||
For `Position`:
|
||||
Using Account with a Position
|
||||
exchange_kwargs : dict
|
||||
the kwargs for initializing Exchange
|
||||
E.g.
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
.. code-block:: python
|
||||
sell_limit = margin
|
||||
|
||||
exchange_kwargs = {
|
||||
"freq": freq,
|
||||
"limit_threshold": None, # limit_threshold is None, using C.limit_threshold
|
||||
"deal_price": None, # deal_price is None, using C.deal_price
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
- else:
|
||||
|
||||
pos_type : str
|
||||
the type of Position.
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
Returns
|
||||
-------
|
||||
report_normal: pd.DataFrame
|
||||
backtest report
|
||||
positions_normal: pd.DataFrame
|
||||
backtest positions
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy, TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
- **exchange related arguments**
|
||||
|
||||
exchange: Exchange()
|
||||
pass the exchange for speeding up.
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. The default value is 0.002(0.2%).
|
||||
close_cost : float
|
||||
close transaction cost. The default value is 0.002(0.2%).
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
whether to pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
freq = "day"
|
||||
if executor is None:
|
||||
executor_config = {
|
||||
"time_per_step": freq,
|
||||
"generate_portfolio_metrics": True,
|
||||
}
|
||||
executor = _executor.SimulatorExecutor(**executor_config)
|
||||
_exchange_kwargs = {
|
||||
"freq": freq,
|
||||
"limit_threshold": None,
|
||||
"deal_price": None,
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
}
|
||||
if exchange_kwargs is not None:
|
||||
_exchange_kwargs.update(exchange_kwargs)
|
||||
|
||||
portfolio_metric_dict, indicator_dict = backtest_func(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
strategy=strategy,
|
||||
executor=executor,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
exchange_kwargs=_exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
warnings.warn("this function is deprecated, please use backtest function in qlib.backtest", DeprecationWarning)
|
||||
report_dict = backtest_func(
|
||||
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
|
||||
)
|
||||
analysis_freq = "{0}{1}".format(*Freq.parse(freq))
|
||||
|
||||
report_normal, positions_normal = portfolio_metric_dict.get(analysis_freq)
|
||||
|
||||
return report_normal, positions_normal
|
||||
return report_dict.get("report_df"), report_dict.get("positions")
|
||||
|
||||
|
||||
def long_short_backtest(
|
||||
@@ -375,12 +327,7 @@ def t_run():
|
||||
pred["datetime"] = pd.to_datetime(pred["datetime"])
|
||||
pred = pred.set_index([pred.columns[0], pred.columns[1]])
|
||||
pred = pred.iloc[:9000]
|
||||
strategy_config = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
"signal": pred,
|
||||
}
|
||||
report_df, positions = backtest_daily(start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_config)
|
||||
report_df, positions = backtest(pred=pred)
|
||||
print(report_df.head())
|
||||
print(positions.keys())
|
||||
print(positions[list(positions.keys())[0]])
|
||||
|
||||
@@ -30,10 +30,8 @@ try:
|
||||
from .pytorch_nn import DNNModelPytorch
|
||||
from .pytorch_tabnet import TabnetModel
|
||||
from .pytorch_sfm import SFM_Model
|
||||
from .pytorch_tcn import TCN
|
||||
from .pytorch_add import ADD
|
||||
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD)
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
|
||||
@@ -14,12 +14,11 @@ from ...model.interpret.base import LightGBMFInt
|
||||
class LGBModel(ModelFT, LightGBMFInt):
|
||||
"""LightGBM Model"""
|
||||
|
||||
def __init__(self, loss="mse", early_stopping_rounds=50, **kwargs):
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self.params = {"objective": loss, "verbosity": -1}
|
||||
self.params.update(kwargs)
|
||||
self.early_stopping_rounds = early_stopping_rounds
|
||||
self.model = None
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
@@ -45,7 +44,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=None,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
@@ -57,9 +56,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
),
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
|
||||
@@ -1,789 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
import os
|
||||
from pdb import set_trace
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import copy
|
||||
from typing import Text, Union
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.autograd import Function
|
||||
from qlib.contrib.model.pytorch_utils import count_parameters
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import get_or_create_path
|
||||
|
||||
|
||||
class ADARNN(Model):
|
||||
"""ADARNN Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used for early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
pre_epoch=40,
|
||||
dw=0.5,
|
||||
loss_type="cosine",
|
||||
len_seq=60,
|
||||
len_win=0,
|
||||
lr=0.001,
|
||||
metric="mse",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
n_splits=2,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADARNN")
|
||||
self.logger.info("ADARNN pytorch version...")
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.pre_epoch = pre_epoch
|
||||
self.dw = dw
|
||||
self.loss_type = loss_type
|
||||
self.len_seq = len_seq
|
||||
self.len_win = len_win
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.n_splits = n_splits
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"ADARNN parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
n_hiddens = [hidden_size for _ in range(num_layers)]
|
||||
self.model = AdaRNN(
|
||||
use_bottleneck=False,
|
||||
bottleneck_width=64,
|
||||
n_input=d_feat,
|
||||
n_hiddens=n_hiddens,
|
||||
n_output=1,
|
||||
dropout=dropout,
|
||||
model_type="AdaRNN",
|
||||
len_seq=len_seq,
|
||||
trans_loss=loss_type,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.model.cuda()
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None):
|
||||
self.model.train()
|
||||
criterion = nn.MSELoss()
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
len_loader = np.inf
|
||||
for loader in train_loader_list:
|
||||
if len(loader) < len_loader:
|
||||
len_loader = len(loader)
|
||||
for data_all in zip(*train_loader_list):
|
||||
# for data_all in zip(*train_loader_list):
|
||||
self.train_optimizer.zero_grad()
|
||||
list_feat = []
|
||||
list_label = []
|
||||
for data in data_all:
|
||||
# feature :[36, 24, 6]
|
||||
feature, label_reg = data[0].cuda().float(), data[1].cuda().float()
|
||||
list_feat.append(feature)
|
||||
list_label.append(label_reg)
|
||||
flag = False
|
||||
index = get_index(len(data_all) - 1)
|
||||
for temp_index in index:
|
||||
s1 = temp_index[0]
|
||||
s2 = temp_index[1]
|
||||
if list_feat[s1].shape[0] != list_feat[s2].shape[0]:
|
||||
flag = True
|
||||
break
|
||||
if flag:
|
||||
continue
|
||||
|
||||
total_loss = torch.zeros(1).cuda()
|
||||
for i in range(len(index)):
|
||||
feature_s = list_feat[index[i][0]]
|
||||
feature_t = list_feat[index[i][1]]
|
||||
label_reg_s = list_label[index[i][0]]
|
||||
label_reg_t = list_label[index[i][1]]
|
||||
feature_all = torch.cat((feature_s, feature_t), 0)
|
||||
|
||||
if epoch < self.pre_epoch:
|
||||
pred_all, loss_transfer, out_weight_list = self.model.forward_pre_train(
|
||||
feature_all, len_win=self.len_win
|
||||
)
|
||||
else:
|
||||
pred_all, loss_transfer, dist, weight_mat = self.model.forward_Boosting(feature_all, weight_mat)
|
||||
dist_mat = dist_mat + dist
|
||||
pred_s = pred_all[0 : feature_s.size(0)]
|
||||
pred_t = pred_all[feature_s.size(0) :]
|
||||
|
||||
loss_s = criterion(pred_s, label_reg_s)
|
||||
loss_t = criterion(pred_t, label_reg_t)
|
||||
|
||||
total_loss = total_loss + loss_s + loss_t + self.dw * loss_transfer
|
||||
self.train_optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
if epoch >= self.pre_epoch:
|
||||
if epoch > self.pre_epoch:
|
||||
weight_mat = self.model.update_weight_Boosting(weight_mat, dist_old, dist_mat)
|
||||
return weight_mat, dist_mat
|
||||
else:
|
||||
weight_mat = self.transform_type(out_weight_list)
|
||||
return weight_mat, None
|
||||
|
||||
def calc_all_metrics(self, pred):
|
||||
"""pred is a pandas dataframe that has two attributes: score (pred) and label (real)"""
|
||||
res = {}
|
||||
ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score))
|
||||
rank_ic = pred.groupby(level="datetime").apply(lambda x: x.label.corr(x.score, method="spearman"))
|
||||
res["ic"] = ic.mean()
|
||||
res["icir"] = ic.mean() / ic.std()
|
||||
res["ric"] = rank_ic.mean()
|
||||
res["ricir"] = rank_ic.mean() / rank_ic.std()
|
||||
res["mse"] = -((pred["label"] - pred["score"]) ** 2).mean()
|
||||
res["loss"] = res["mse"]
|
||||
return res
|
||||
|
||||
def test_epoch(self, df):
|
||||
self.model.eval()
|
||||
preds = self.infer(df["feature"])
|
||||
label = df["label"].squeeze()
|
||||
preds = pd.DataFrame({"label": label, "score": preds}, index=df.index)
|
||||
metrics = self.calc_all_metrics(preds)
|
||||
return metrics
|
||||
|
||||
def log_metrics(self, mode, metrics):
|
||||
metrics = ["{}/{}: {:.6f}".format(k, mode, v) for k, v in metrics.items()]
|
||||
metrics = ", ".join(metrics)
|
||||
self.logger.info(metrics)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
# splits = ['2011-06-30']
|
||||
days = df_train.index.get_level_values(level=0).unique()
|
||||
train_splits = np.array_split(days, self.n_splits)
|
||||
train_splits = [df_train[s[0] : s[-1]] for s in train_splits]
|
||||
train_loader_list = [get_stock_loader(df, self.batch_size) for df in train_splits]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
weight_mat, dist_mat = None, None
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
weight_mat, dist_mat = self.train_AdaRNN(train_loader_list, step, dist_mat, weight_mat)
|
||||
self.logger.info("evaluating...")
|
||||
train_metrics = self.test_epoch(df_train)
|
||||
valid_metrics = self.test_epoch(df_valid)
|
||||
self.log_metrics("train: ", train_metrics)
|
||||
self.log_metrics("valid: ", valid_metrics)
|
||||
|
||||
valid_score = valid_metrics[self.metric]
|
||||
train_score = train_metrics[self.metric]
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(valid_score)
|
||||
if valid_score > best_score:
|
||||
best_score = valid_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
return best_score
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return self.infer(x_test)
|
||||
|
||||
def infer(self, x_test):
|
||||
index = x_test.index
|
||||
self.model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
x_values = x_values.reshape(sample_num, self.d_feat, -1).transpose(0, 2, 1)
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.model.predict(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
def transform_type(self, init_weight):
|
||||
weight = torch.ones(self.num_layers, self.len_seq).cuda()
|
||||
for i in range(self.num_layers):
|
||||
for j in range(self.len_seq):
|
||||
weight[i, j] = init_weight[i][j].item()
|
||||
return weight
|
||||
|
||||
|
||||
class data_loader(Dataset):
|
||||
def __init__(self, df):
|
||||
self.df_feature = df["feature"]
|
||||
self.df_label_reg = df["label"]
|
||||
self.df_index = df.index
|
||||
self.df_feature = torch.tensor(
|
||||
self.df_feature.values.reshape(-1, 6, 60).transpose(0, 2, 1), dtype=torch.float32
|
||||
)
|
||||
self.df_label_reg = torch.tensor(self.df_label_reg.values.reshape(-1), dtype=torch.float32)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample, label_reg = self.df_feature[index], self.df_label_reg[index]
|
||||
return sample, label_reg
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df_feature)
|
||||
|
||||
|
||||
def get_stock_loader(df, batch_size, shuffle=True):
|
||||
train_loader = DataLoader(data_loader(df), batch_size=batch_size, shuffle=shuffle)
|
||||
return train_loader
|
||||
|
||||
|
||||
def get_index(num_domain=2):
|
||||
index = []
|
||||
for i in range(num_domain):
|
||||
for j in range(i + 1, num_domain + 1):
|
||||
index.append((i, j))
|
||||
return index
|
||||
|
||||
|
||||
class AdaRNN(nn.Module):
|
||||
"""
|
||||
model_type: 'Boosting', 'AdaRNN'
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_bottleneck=False,
|
||||
bottleneck_width=256,
|
||||
n_input=128,
|
||||
n_hiddens=[64, 64],
|
||||
n_output=6,
|
||||
dropout=0.0,
|
||||
len_seq=9,
|
||||
model_type="AdaRNN",
|
||||
trans_loss="mmd",
|
||||
):
|
||||
super(AdaRNN, self).__init__()
|
||||
self.use_bottleneck = use_bottleneck
|
||||
self.n_input = n_input
|
||||
self.num_layers = len(n_hiddens)
|
||||
self.hiddens = n_hiddens
|
||||
self.n_output = n_output
|
||||
self.model_type = model_type
|
||||
self.trans_loss = trans_loss
|
||||
self.len_seq = len_seq
|
||||
in_size = self.n_input
|
||||
|
||||
features = nn.ModuleList()
|
||||
for hidden in n_hiddens:
|
||||
rnn = nn.GRU(input_size=in_size, num_layers=1, hidden_size=hidden, batch_first=True, dropout=dropout)
|
||||
features.append(rnn)
|
||||
in_size = hidden
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
if use_bottleneck == True: # finance
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Linear(n_hiddens[-1], bottleneck_width),
|
||||
nn.Linear(bottleneck_width, bottleneck_width),
|
||||
nn.BatchNorm1d(bottleneck_width),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(),
|
||||
)
|
||||
self.bottleneck[0].weight.data.normal_(0, 0.005)
|
||||
self.bottleneck[0].bias.data.fill_(0.1)
|
||||
self.bottleneck[1].weight.data.normal_(0, 0.005)
|
||||
self.bottleneck[1].bias.data.fill_(0.1)
|
||||
self.fc = nn.Linear(bottleneck_width, n_output)
|
||||
torch.nn.init.xavier_normal_(self.fc.weight)
|
||||
else:
|
||||
self.fc_out = nn.Linear(n_hiddens[-1], self.n_output)
|
||||
|
||||
if self.model_type == "AdaRNN":
|
||||
gate = nn.ModuleList()
|
||||
for i in range(len(n_hiddens)):
|
||||
gate_weight = nn.Linear(len_seq * self.hiddens[i] * 2, len_seq)
|
||||
gate.append(gate_weight)
|
||||
self.gate = gate
|
||||
|
||||
bnlst = nn.ModuleList()
|
||||
for i in range(len(n_hiddens)):
|
||||
bnlst.append(nn.BatchNorm1d(len_seq))
|
||||
self.bn_lst = bnlst
|
||||
self.softmax = torch.nn.Softmax(dim=0)
|
||||
self.init_layers()
|
||||
|
||||
def init_layers(self):
|
||||
for i in range(len(self.hiddens)):
|
||||
self.gate[i].weight.data.normal_(0, 0.05)
|
||||
self.gate[i].bias.data.fill_(0.0)
|
||||
|
||||
def forward_pre_train(self, x, len_win=0):
|
||||
out = self.gru_features(x)
|
||||
fea = out[0] # [2N,L,H]
|
||||
if self.use_bottleneck == True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
fc_out = self.fc_out(fea[:, -1, :]).squeeze() # [N,]
|
||||
|
||||
out_list_all, out_weight_list = out[1], out[2]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
h_start = 0
|
||||
for j in range(h_start, self.len_seq, 1):
|
||||
i_start = j - len_win if j - len_win >= 0 else 0
|
||||
i_end = j + len_win if j + len_win < self.len_seq else self.len_seq - 1
|
||||
for k in range(i_start, i_end + 1):
|
||||
weight = (
|
||||
out_weight_list[i][j]
|
||||
if self.model_type == "AdaRNN"
|
||||
else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
|
||||
)
|
||||
loss_transfer = loss_transfer + weight * criterion_transder.compute(
|
||||
out_list_s[i][:, j, :], out_list_t[i][:, k, :]
|
||||
)
|
||||
return fc_out, loss_transfer, out_weight_list
|
||||
|
||||
def gru_features(self, x, predict=False):
|
||||
x_input = x
|
||||
out = None
|
||||
out_lis = []
|
||||
out_weight_list = [] if (self.model_type == "AdaRNN") else None
|
||||
for i in range(self.num_layers):
|
||||
out, _ = self.features[i](x_input.float())
|
||||
x_input = out
|
||||
out_lis.append(out)
|
||||
if self.model_type == "AdaRNN" and predict == False:
|
||||
out_gate = self.process_gate_weight(x_input, i)
|
||||
out_weight_list.append(out_gate)
|
||||
return out, out_lis, out_weight_list
|
||||
|
||||
def process_gate_weight(self, out, index):
|
||||
x_s = out[0 : int(out.shape[0] // 2)]
|
||||
x_t = out[out.shape[0] // 2 : out.shape[0]]
|
||||
x_all = torch.cat((x_s, x_t), 2)
|
||||
x_all = x_all.view(x_all.shape[0], -1)
|
||||
weight = torch.sigmoid(self.bn_lst[index](self.gate[index](x_all.float())))
|
||||
weight = torch.mean(weight, dim=0)
|
||||
res = self.softmax(weight).squeeze()
|
||||
return res
|
||||
|
||||
def get_features(self, output_list):
|
||||
fea_list_src, fea_list_tar = [], []
|
||||
for fea in output_list:
|
||||
fea_list_src.append(fea[0 : fea.size(0) // 2])
|
||||
fea_list_tar.append(fea[fea.size(0) // 2 :])
|
||||
return fea_list_src, fea_list_tar
|
||||
|
||||
# For Boosting-based
|
||||
def forward_Boosting(self, x, weight_mat=None):
|
||||
out = self.gru_features(x)
|
||||
fea = out[0]
|
||||
if self.use_bottleneck:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
fc_out = self.fc_out(fea[:, -1, :]).squeeze()
|
||||
|
||||
out_list_all = out[1]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
if weight_mat is None:
|
||||
weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).cuda()
|
||||
else:
|
||||
weight = weight_mat
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
for j in range(self.len_seq):
|
||||
loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :])
|
||||
loss_transfer = loss_transfer + weight[i, j] * loss_trans
|
||||
dist_mat[i, j] = loss_trans
|
||||
return fc_out, loss_transfer, dist_mat, weight
|
||||
|
||||
    # For Boosting-based
    def update_weight_Boosting(self, weight_mat, dist_old, dist_new):
        epsilon = 1e-5
        dist_old = dist_old.detach()
        dist_new = dist_new.detach()
        ind = dist_new > dist_old + epsilon
        weight_mat[ind] = weight_mat[ind] * (1 + torch.sigmoid(dist_new[ind] - dist_old[ind]))
        weight_norm = torch.norm(weight_mat, dim=1, p=1)
        weight_mat = weight_mat / weight_norm.t().unsqueeze(1).repeat(1, self.len_seq)
        return weight_mat
|
||||
|
||||
def predict(self, x):
|
||||
out = self.gru_features(x, predict=True)
|
||||
fea = out[0]
|
||||
if self.use_bottleneck == True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
fc_out = self.fc_out(fea[:, -1, :]).squeeze()
|
||||
return fc_out
|
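The weight update in `update_weight_Boosting` can be hard to read in tensor form. The following is a minimal pure-Python sketch of the same rule on nested lists, assuming one row per layer and one column per time step. The helper names `sigmoid` and `update_weights` and the sample values are illustrative and do not appear in the diff.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def update_weights(weight_mat, dist_old, dist_new, epsilon=1e-5):
    """Illustrative re-implementation of the boosting-style update on nested lists."""
    updated = []
    for w_row, old_row, new_row in zip(weight_mat, dist_old, dist_new):
        row = []
        for w, d_old, d_new in zip(w_row, old_row, new_row):
            # Only entries whose distance grew by more than epsilon are boosted.
            if d_new > d_old + epsilon:
                w = w * (1.0 + sigmoid(d_new - d_old))
            row.append(w)
        # L1-normalise each row so the weights of one layer sum to 1.
        norm = sum(abs(v) for v in row)
        updated.append([v / norm for v in row])
    return updated


weights = update_weights(
    weight_mat=[[0.5, 0.5]],
    dist_old=[[1.0, 1.0]],
    dist_new=[[1.0, 2.0]],
)
```

In this example only the second time step saw its distance increase, so it ends up with the larger share of the row's weight after normalisation.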
||||
|
||||
|
||||
class TransferLoss(object):
|
||||
def __init__(self, loss_type="cosine", input_dim=512):
|
||||
"""
|
||||
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
|
||||
"""
|
||||
self.loss_type = loss_type
|
||||
self.input_dim = input_dim
|
||||
|
||||
def compute(self, X, Y):
|
||||
"""Compute adaptation loss
|
||||
|
||||
Arguments:
|
||||
X {tensor} -- source matrix
|
||||
Y {tensor} -- target matrix
|
||||
|
||||
Returns:
|
||||
[tensor] -- transfer loss
|
||||
"""
|
||||
if self.loss_type == "mmd_lin" or self.loss_type == "mmd":
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
elif self.loss_type == "coral":
|
||||
loss = CORAL(X, Y)
|
||||
elif self.loss_type == "cosine" or self.loss_type == "cos":
|
||||
loss = 1 - cosine(X, Y)
|
||||
elif self.loss_type == "kl":
|
||||
loss = kl_div(X, Y)
|
||||
elif self.loss_type == "js":
|
||||
loss = js(X, Y)
|
||||
elif self.loss_type == "mine":
|
||||
mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).cuda()
|
||||
loss = mine_model(X, Y)
|
||||
elif self.loss_type == "adv":
|
||||
loss = adv(X, Y, input_dim=self.input_dim, hidden_dim=32)
|
||||
elif self.loss_type == "mmd_rbf":
|
||||
mmdloss = MMD_loss(kernel_type="rbf")
|
||||
loss = mmdloss(X, Y)
|
||||
elif self.loss_type == "pairwise":
|
||||
pair_mat = pairwise_dist(X, Y)
|
||||
loss = torch.norm(pair_mat)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def cosine(source, target):
|
||||
source, target = source.mean(), target.mean()
|
||||
cos = nn.CosineSimilarity(dim=0)
|
||||
loss = cos(source, target)
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class ReverseLayerF(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, alpha):
|
||||
ctx.alpha = alpha
|
||||
return x.view_as(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
output = grad_output.neg() * ctx.alpha
|
||||
return output, None
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, input_dim=256, hidden_dim=256):
|
||||
super(Discriminator, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.dis1 = nn.Linear(input_dim, hidden_dim)
|
||||
self.dis2 = nn.Linear(hidden_dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.dis1(x))
|
||||
x = self.dis2(x)
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
def adv(source, target, input_dim=256, hidden_dim=512):
|
||||
domain_loss = nn.BCELoss()
|
||||
# !!! Pay attention to .cuda !!!
|
||||
adv_net = Discriminator(input_dim, hidden_dim).cuda()
|
||||
domain_src = torch.ones(len(source)).cuda()
|
||||
domain_tar = torch.zeros(len(target)).cuda()
|
||||
domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1)
|
||||
reverse_src = ReverseLayerF.apply(source, 1)
|
||||
reverse_tar = ReverseLayerF.apply(target, 1)
|
||||
pred_src = adv_net(reverse_src)
|
||||
pred_tar = adv_net(reverse_tar)
|
||||
loss_s, loss_t = domain_loss(pred_src, domain_src), domain_loss(pred_tar, domain_tar)
|
||||
loss = loss_s + loss_t
|
||||
return loss
|
||||
|
||||
|
||||
def CORAL(source, target):
|
||||
d = source.size(1)
|
||||
ns, nt = source.size(0), target.size(0)
|
||||
|
||||
# source covariance
|
||||
tmp_s = torch.ones((1, ns)).cuda() @ source
|
||||
cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
|
||||
|
||||
# target covariance
|
||||
tmp_t = torch.ones((1, nt)).cuda() @ target
|
||||
ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
|
||||
|
||||
# frobenius norm
|
||||
loss = (cs - ct).pow(2).sum()
|
||||
loss = loss / (4 * d * d)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class MMD_loss(nn.Module):
|
||||
def __init__(self, kernel_type="linear", kernel_mul=2.0, kernel_num=5):
|
||||
super(MMD_loss, self).__init__()
|
||||
self.kernel_num = kernel_num
|
||||
self.kernel_mul = kernel_mul
|
||||
self.fix_sigma = None
|
||||
self.kernel_type = kernel_type
|
||||
|
||||
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
|
||||
n_samples = int(source.size()[0]) + int(target.size()[0])
|
||||
total = torch.cat([source, target], dim=0)
|
||||
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||
total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
|
||||
L2_distance = ((total0 - total1) ** 2).sum(2)
|
||||
if fix_sigma:
|
||||
bandwidth = fix_sigma
|
||||
else:
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
|
||||
bandwidth /= kernel_mul ** (kernel_num // 2)
|
||||
bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
|
||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||
return sum(kernel_val)
|
||||
|
||||
def linear_mmd(self, X, Y):
|
||||
delta = X.mean(axis=0) - Y.mean(axis=0)
|
||||
loss = delta.dot(delta.T)
|
||||
return loss
|
||||
|
||||
def forward(self, source, target):
|
||||
if self.kernel_type == "linear":
|
||||
return self.linear_mmd(source, target)
|
||||
elif self.kernel_type == "rbf":
|
||||
batch_size = int(source.size()[0])
|
||||
kernels = self.guassian_kernel(
|
||||
source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma
|
||||
)
|
||||
with torch.no_grad():
|
||||
XX = torch.mean(kernels[:batch_size, :batch_size])
|
||||
YY = torch.mean(kernels[batch_size:, batch_size:])
|
||||
XY = torch.mean(kernels[:batch_size, batch_size:])
|
||||
YX = torch.mean(kernels[batch_size:, :batch_size])
|
||||
loss = torch.mean(XX + YY - XY - YX)
|
||||
return loss
|
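For the `linear` kernel, `MMD_loss` reduces to the squared Euclidean distance between the feature means of the two batches. A minimal sketch of that computation in plain Python, assuming row-major nested lists; the function name `linear_mmd_lists` and the sample data are hypothetical:

```python
def linear_mmd_lists(X, Y):
    """Squared Euclidean distance between the column means of X and Y."""
    n_features = len(X[0])
    mean_x = [sum(row[k] for row in X) / len(X) for k in range(n_features)]
    mean_y = [sum(row[k] for row in Y) / len(Y) for k in range(n_features)]
    return sum((a - b) ** 2 for a, b in zip(mean_x, mean_y))


# Source batch has column means [1, 1]; target batch has column means [1, 3].
loss = linear_mmd_lists([[0.0, 0.0], [2.0, 2.0]], [[1.0, 3.0]])
```

Only the first moments are compared here, so two batches with equal means but different spreads give a loss of zero. Capturing such differences is what the `rbf` branch is for.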
||||
|
||||
|
||||
class Mine_estimator(nn.Module):
|
||||
def __init__(self, input_dim=2048, hidden_dim=512):
|
||||
super(Mine_estimator, self).__init__()
|
||||
self.mine_model = Mine(input_dim, hidden_dim)
|
||||
|
||||
def forward(self, X, Y):
|
||||
Y_shffle = Y[torch.randperm(len(Y))]
|
||||
loss_joint = self.mine_model(X, Y)
|
||||
loss_marginal = self.mine_model(X, Y_shffle)
|
||||
ret = torch.mean(loss_joint) - torch.log(torch.mean(torch.exp(loss_marginal)))
|
||||
loss = -ret
|
||||
return loss
|
||||
|
||||
|
||||
class Mine(nn.Module):
|
||||
def __init__(self, input_dim=2048, hidden_dim=512):
|
||||
super(Mine, self).__init__()
|
||||
self.fc1_x = nn.Linear(input_dim, hidden_dim)
|
||||
self.fc1_y = nn.Linear(input_dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, 1)
|
||||
|
||||
def forward(self, x, y):
|
||||
h1 = F.leaky_relu(self.fc1_x(x) + self.fc1_y(y))
|
||||
h2 = self.fc2(h1)
|
||||
return h2
|
||||
|
||||
|
||||
def pairwise_dist(X, Y):
|
||||
n, d = X.shape
|
||||
m, _ = Y.shape
|
||||
assert d == Y.shape[1]
|
||||
a = X.unsqueeze(1).expand(n, m, d)
|
||||
b = Y.unsqueeze(0).expand(n, m, d)
|
||||
return torch.pow(a - b, 2).sum(2)
|
||||
|
||||
|
||||
def pairwise_dist_np(X, Y):
|
||||
n, d = X.shape
|
||||
m, _ = Y.shape
|
||||
assert d == Y.shape[1]
|
||||
a = np.expand_dims(X, 1)
|
||||
b = np.expand_dims(Y, 0)
|
||||
a = np.tile(a, (1, m, 1))
|
||||
b = np.tile(b, (n, 1, 1))
|
||||
return np.power(a - b, 2).sum(2)
|
||||
|
||||
|
||||
def pa(X, Y):
|
||||
XY = np.dot(X, Y.T)
|
||||
XX = np.sum(np.square(X), axis=1)
|
||||
XX = np.transpose([XX])
|
||||
YY = np.sum(np.square(Y), axis=1)
|
||||
dist = XX + YY - 2 * XY
|
||||
|
||||
return dist
|
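`pairwise_dist_np` and `pa` both compute the matrix of squared Euclidean distances. The first subtracts every pair of rows directly. The second relies on the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y. The sketch below checks that equivalence on a tiny input in plain Python; the function names are illustrative and not part of the diff.

```python
def pairwise_sq_dist(X, Y):
    """Direct form: sum over features of (x - y)^2 for every pair of rows."""
    return [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in Y] for x in X]


def pairwise_sq_dist_expanded(X, Y):
    """Expanded form: ||x||^2 + ||y||^2 - 2 * <x, y>."""
    xx = [sum(a * a for a in x) for x in X]
    yy = [sum(b * b for b in y) for y in Y]
    return [
        [xx[i] + yy[j] - 2 * sum(a * b for a, b in zip(X[i], Y[j])) for j in range(len(Y))]
        for i in range(len(X))
    ]


X = [[0.0, 0.0], [1.0, 2.0]]
Y = [[3.0, 4.0]]
direct = pairwise_sq_dist(X, Y)
expanded = pairwise_sq_dist_expanded(X, Y)
```

The expanded form avoids materialising an `n x m x d` intermediate array, at the cost of possible small negative values from floating-point cancellation on real data.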
||||
|
||||
|
||||
def kl_div(source, target):
|
||||
if len(source) < len(target):
|
||||
target = target[: len(source)]
|
||||
elif len(source) > len(target):
|
||||
source = source[: len(target)]
|
||||
criterion = nn.KLDivLoss(reduction="batchmean")
|
||||
loss = criterion(source.log(), target)
|
||||
return loss
|
||||
|
||||
|
||||
def js(source, target):
|
||||
if len(source) < len(target):
|
||||
target = target[: len(source)]
|
||||
elif len(source) > len(target):
|
||||
source = source[: len(target)]
|
||||
M = 0.5 * (source + target)
|
||||
loss_1, loss_2 = kl_div(source, M), kl_div(target, M)
|
||||
return 0.5 * (loss_1 + loss_2)
|
||||
@@ -1,598 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Text, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from qlib.contrib.model.pytorch_gru import GRUModel
|
||||
from qlib.contrib.model.pytorch_lstm import LSTMModel
|
||||
from qlib.contrib.model.pytorch_utils import count_parameters
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.processor import CSRankNorm
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import get_or_create_path
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class ADD(Model):
|
||||
"""ADD Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lr : float
|
||||
learning rate
|
||||
d_feat : int
|
||||
input dimensions for each time step
|
||||
metric : str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
dec_dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="mse",
|
||||
batch_size=5000,
|
||||
early_stop=20,
|
||||
base_model="GRU",
|
||||
model_path=None,
|
||||
optimizer="adam",
|
||||
gamma=0.1,
|
||||
gamma_clip=0.4,
|
||||
mu=0.05,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("ADD")
|
||||
self.logger.info("ADD pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.dec_dropout = dec_dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.base_model = base_model
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.gamma = gamma
|
||||
self.gamma_clip = gamma_clip
|
||||
self.mu = mu
|
||||
|
||||
self.logger.info(
|
||||
"ADD parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\ndec_dropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nbase_model : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\ngamma : {}"
|
||||
"\ngamma_clip : {}"
|
||||
"\nmu : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
dec_dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
base_model,
|
||||
model_path,
|
||||
gamma,
|
||||
gamma_clip,
|
||||
mu,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.ADD_model = ADDModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
dec_dropout=self.dec_dropout,
|
||||
base_model=self.base_model,
|
||||
gamma=self.gamma,
|
||||
gamma_clip=self.gamma_clip,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ADD_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ADD_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ADD_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.ADD_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.ADD_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def loss_pre_excess(self, pred_excess, label_excess, record=None):
|
||||
mask = ~torch.isnan(label_excess)
|
||||
pre_excess_loss = F.mse_loss(pred_excess[mask], label_excess[mask])
|
||||
if record is not None:
|
||||
record["pre_excess_loss"] = pre_excess_loss.item()
|
||||
return pre_excess_loss
|
||||
|
||||
def loss_pre_market(self, pred_market, label_market, record=None):
|
||||
pre_market_loss = F.cross_entropy(pred_market, label_market)
|
||||
if record is not None:
|
||||
record["pre_market_loss"] = pre_market_loss.item()
|
||||
return pre_market_loss
|
||||
|
||||
def loss_pre(self, pred_excess, label_excess, pred_market, label_market, record=None):
|
||||
pre_loss = self.loss_pre_excess(pred_excess, label_excess, record) + self.loss_pre_market(
|
||||
pred_market, label_market, record
|
||||
)
|
||||
if record is not None:
|
||||
record["pre_loss"] = pre_loss.item()
|
||||
return pre_loss
|
||||
|
||||
def loss_adv_excess(self, adv_excess, label_excess, record=None):
|
||||
mask = ~torch.isnan(label_excess)
|
||||
adv_excess_loss = F.mse_loss(adv_excess.squeeze()[mask], label_excess[mask])
|
||||
if record is not None:
|
||||
record["adv_excess_loss"] = adv_excess_loss.item()
|
||||
return adv_excess_loss
|
||||
|
||||
def loss_adv_market(self, adv_market, label_market, record=None):
|
||||
adv_market_loss = F.cross_entropy(adv_market, label_market)
|
||||
if record is not None:
|
||||
record["adv_market_loss"] = adv_market_loss.item()
|
||||
return adv_market_loss
|
||||
|
||||
def loss_adv(self, adv_excess, label_excess, adv_market, label_market, record=None):
|
||||
adv_loss = self.loss_adv_excess(adv_excess, label_excess, record) + self.loss_adv_market(
|
||||
adv_market, label_market, record
|
||||
)
|
||||
if record is not None:
|
||||
record["adv_loss"] = adv_loss.item()
|
||||
return adv_loss
|
||||
|
||||
def loss_fn(self, x, preds, label_excess, label_market, record=None):
|
||||
loss = (
|
||||
self.loss_pre(preds["excess"], label_excess, preds["market"], label_market, record)
|
||||
+ self.loss_adv(preds["adv_excess"], label_excess, preds["adv_market"], label_market, record)
|
||||
+ self.mu * self.loss_rec(x, preds["reconstructed_feature"], record)
|
||||
)
|
||||
if record is not None:
|
||||
record["loss"] = loss.item()
|
||||
return loss
|
||||
|
||||
def loss_rec(self, x, rec_x, record=None):
|
||||
x = x.reshape(len(x), self.d_feat, -1)
|
||||
x = x.permute(0, 2, 1)
|
||||
rec_loss = F.mse_loss(x, rec_x)
|
||||
if record is not None:
|
||||
record["rec_loss"] = rec_loss.item()
|
||||
return rec_loss
|
||||
|
||||
    def get_daily_inter(self, df, shuffle=False):
        # organize the train data into daily batches
        daily_count = df.groupby(level=0).size().values
        daily_index = np.roll(np.cumsum(daily_count), 1)
        daily_index[0] = 0
        if shuffle:
            # shuffle data
            daily_shuffle = list(zip(daily_index, daily_count))
            np.random.shuffle(daily_shuffle)
            daily_index, daily_count = zip(*daily_shuffle)
        return daily_index, daily_count
|
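`get_daily_inter` turns per-day row counts into start offsets: rolling the cumulative sum by one and resetting the first entry to zero yields an exclusive prefix sum. A small standard-library sketch of the same idea, where `daily_slices` and the sample counts are hypothetical:

```python
from itertools import accumulate


def daily_slices(daily_count):
    """Return (start, count) pairs; start is the exclusive prefix sum of the counts."""
    starts = [0] + list(accumulate(daily_count))[:-1]
    return list(zip(starts, daily_count))


# Three trading days with 3, 2 and 4 instruments respectively.
pairs = daily_slices([3, 2, 4])
rows = list(range(9))
batches = [rows[start : start + count] for start, count in pairs]
```

Each `(start, count)` pair selects exactly one day's rows, which is how `test_epoch` and `predict` in the diff build their per-day batches.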
||||
|
||||
def cal_ic_metrics(self, pred, label):
|
||||
metrics = {}
|
||||
metrics["mse"] = -F.mse_loss(pred, label).item()
|
||||
metrics["loss"] = metrics["mse"]
|
||||
pred = pd.Series(pred.cpu().detach().numpy())
|
||||
label = pd.Series(label.cpu().detach().numpy())
|
||||
metrics["ic"] = pred.corr(label)
|
||||
metrics["ric"] = pred.corr(label, method="spearman")
|
||||
return metrics
|
||||
|
||||
def test_epoch(self, data_x, data_y, data_m):
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
m_values = np.squeeze(data_m.values.astype(int))
|
||||
self.ADD_model.eval()
|
||||
|
||||
metrics_list = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(data_x, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
feature = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
label_excess = torch.from_numpy(y_values[batch]).float().to(self.device)
|
||||
label_market = torch.from_numpy(m_values[batch]).long().to(self.device)
|
||||
|
||||
metrics = {}
|
||||
preds = self.ADD_model(feature)
|
||||
self.loss_fn(feature, preds, label_excess, label_market, metrics)
|
||||
metrics.update(self.cal_ic_metrics(preds["excess"], label_excess))
|
||||
metrics_list.append(metrics)
|
||||
metrics = {}
|
||||
keys = metrics_list[0].keys()
|
||||
for k in keys:
|
||||
vs = [m[k] for m in metrics_list]
|
||||
metrics[k] = sum(vs) / len(vs)
|
||||
|
||||
return metrics
|
||||
|
||||
def train_epoch(self, x_train_values, y_train_values, m_train_values):
|
||||
self.ADD_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
cur_step = 1
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
batch = indices[i : i + self.batch_size]
|
||||
feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
|
||||
label_excess = torch.from_numpy(y_train_values[batch]).float().to(self.device)
|
||||
label_market = torch.from_numpy(m_train_values[batch]).long().to(self.device)
|
||||
|
||||
preds = self.ADD_model(feature)
|
||||
|
||||
loss = self.loss_fn(feature, preds, label_excess, label_market)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.ADD_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
cur_step += 1
|
||||
|
||||
def log_metrics(self, mode, metrics):
|
||||
metrics = ["{}/{}: {:.6f}".format(k, mode, v) for k, v in metrics.items()]
|
||||
metrics = ", ".join(metrics)
|
||||
self.logger.info(metrics)
|
||||
|
||||
def bootstrap_fit(self, x_train, y_train, m_train, x_valid, y_valid, m_valid):
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
m_train_values = np.squeeze(m_train.values.astype(int))
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train_values, y_train_values, m_train_values)
|
||||
self.logger.info("evaluating...")
|
||||
train_metrics = self.test_epoch(x_train, y_train, m_train)
|
||||
valid_metrics = self.test_epoch(x_valid, y_valid, m_valid)
|
||||
self.log_metrics("train", train_metrics)
|
||||
self.log_metrics("valid", valid_metrics)
|
||||
|
||||
if self.metric in valid_metrics:
|
||||
val_score = valid_metrics[self.metric]
|
||||
else:
|
||||
raise ValueError("unknown metric name `%s`" % self.metric)
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.ADD_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
self.ADD_model.before_adv_excess.step_alpha()
|
||||
self.ADD_model.before_adv_market.step_alpha()
|
||||
self.logger.info("bootstrap_fit best score: {:.6f} @ {}".format(best_score, best_epoch))
|
||||
self.ADD_model.load_state_dict(best_param)
|
||||
return best_score
|
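The early-stopping logic in `bootstrap_fit` treats every metric as higher-is-better, which is why `cal_ic_metrics` stores the negated MSE. The loop below is a minimal sketch of that control flow on a precomputed list of validation scores. The function name `early_stopping` and the score values are illustrative assumptions.

```python
def early_stopping(scores, patience):
    """Return (best_score, best_epoch, epochs_run) for a higher-is-better score sequence."""
    best_score = float("-inf")
    best_epoch = 0
    stop_steps = 0
    epochs_run = 0
    for epoch, score in enumerate(scores):
        epochs_run = epoch + 1
        if score > best_score:
            # New best: remember it and reset the patience counter.
            best_score, best_epoch, stop_steps = score, epoch, 0
        else:
            stop_steps += 1
            if stop_steps >= patience:
                break
    return best_score, best_epoch, epochs_run


result = early_stopping([-0.9, -0.5, -0.6, -0.7, -0.4], patience=2)
```

With a patience of 2 the loop stops after the fourth epoch, so the better score at the fifth epoch is never seen. Raising `early_stop` trades longer training for a lower chance of stopping on a temporary plateau.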
||||
|
||||
    def gen_market_label(self, df, raw_label):
        market_label = raw_label.groupby("datetime").mean().squeeze()
        bins = [-np.inf, self.lo, self.hi, np.inf]
        market_label = pd.cut(market_label, bins, labels=False)
        market_label.name = ("market_return", "market_return")
        df = df.join(market_label)
        return df

    def fit_thresh(self, train_label):
        market_label = train_label.groupby("datetime").mean().squeeze()
        self.lo, self.hi = market_label.quantile([1 / 3, 2 / 3])
|
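`fit_thresh` and `gen_market_label` together map each day's mean return to one of three market classes. `pd.cut` with bins `[-inf, lo, hi, inf]` and `labels=False` uses right-closed intervals by default, so a value equal to `lo` falls in class 0. A minimal sketch of that binning rule without pandas follows; `market_class` and the threshold values are hypothetical.

```python
from bisect import bisect_left


def market_class(daily_mean_return, lo, hi):
    """Map a daily mean return to 0 (<= lo), 1 (lo < r <= hi) or 2 (> hi)."""
    return bisect_left([lo, hi], daily_mean_return)


# Hypothetical thresholds; the diff fits them as the 1/3 and 2/3 quantiles of the training labels.
lo, hi = -0.01, 0.01
labels = [market_class(r, lo, hi) for r in [-0.02, -0.01, 0.0, 0.01, 0.03]]
```

Because the thresholds are fitted on the training segment only, the three classes are roughly balanced there but need not be balanced on the validation segment.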
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
label_train, label_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["label"],
|
||||
data_key=DataHandlerLP.DK_R,
|
||||
)
|
||||
self.fit_thresh(label_train)
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
df_train = self.gen_market_label(df_train, label_train)
|
||||
df_valid = self.gen_market_label(df_valid, label_valid)
|
||||
|
||||
x_train, y_train, m_train = df_train["feature"], df_train["label"], df_train["market_return"]
|
||||
x_valid, y_valid, m_valid = df_valid["feature"], df_valid["label"], df_valid["market_return"]
|
||||
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
# load pretrained base_model
|
||||
|
||||
if self.base_model == "LSTM":
|
||||
pretrained_model = LSTMModel()
|
||||
elif self.base_model == "GRU":
|
||||
pretrained_model = GRUModel()
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % self.base_model)
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.ADD_model.enc_excess.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.ADD_model.enc_excess.load_state_dict(model_dict)
|
||||
model_dict = self.ADD_model.enc_market.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.ADD_model.enc_market.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
self.bootstrap_fit(x_train, y_train, m_train, x_valid, y_valid, m_valid)
|
||||
|
||||
best_param = copy.deepcopy(self.ADD_model.state_dict())
|
||||
save_path = get_or_create_path(save_path)
|
||||
torch.save(best_param, save_path)
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.ADD_model.eval()
|
||||
x_values = x_test.values
|
||||
preds = []
|
||||
|
||||
daily_index, daily_count = self.get_daily_inter(x_test, shuffle=False)
|
||||
|
||||
for idx, count in zip(daily_index, daily_count):
|
||||
batch = slice(idx, idx + count)
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.ADD_model(x_batch)
|
||||
pred = pred["excess"].detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
r = pd.Series(np.concatenate(preds), index=index)
|
||||
return r
|
||||
|
||||
|
||||
class ADDModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
dec_dropout=0.5,
|
||||
base_model="GRU",
|
||||
gamma=0.1,
|
||||
gamma_clip=0.4,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_feat = d_feat
|
||||
self.base_model = base_model
|
||||
if base_model == "GRU":
|
||||
self.enc_excess, self.enc_market = [
|
||||
nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
elif base_model == "LSTM":
|
||||
self.enc_excess, self.enc_market = [
|
||||
nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
self.dec = Decoder(d_feat, 2 * hidden_size, num_layers, dec_dropout, base_model)
|
||||
|
||||
ctx_size = hidden_size * num_layers
|
||||
self.pred_excess, self.adv_excess = [
|
||||
nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 1))
|
||||
for _ in range(2)
|
||||
]
|
||||
self.adv_market, self.pred_market = [
|
||||
nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 3))
|
||||
for _ in range(2)
|
||||
]
|
||||
self.before_adv_market, self.before_adv_excess = [RevGrad(gamma, gamma_clip) for _ in range(2)]
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(len(x), self.d_feat, -1)
|
||||
N = x.shape[0]
|
||||
T = x.shape[-1]
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
out, hidden_excess = self.enc_excess(x)
|
||||
out, hidden_market = self.enc_market(x)
|
||||
if self.base_model == "LSTM":
|
||||
feature_excess = hidden_excess[0].permute(1, 0, 2).reshape(N, -1)
|
||||
feature_market = hidden_market[0].permute(1, 0, 2).reshape(N, -1)
|
||||
else:
|
||||
feature_excess = hidden_excess.permute(1, 0, 2).reshape(N, -1)
|
||||
feature_market = hidden_market.permute(1, 0, 2).reshape(N, -1)
|
||||
predicts = {}
|
||||
predicts["excess"] = self.pred_excess(feature_excess).squeeze(1)
|
||||
predicts["market"] = self.pred_market(feature_market)
|
||||
predicts["adv_market"] = self.adv_market(self.before_adv_market(feature_excess))
|
||||
predicts["adv_excess"] = self.adv_excess(self.before_adv_excess(feature_market).squeeze(1))
|
||||
if self.base_model == "LSTM":
|
||||
hidden = [torch.cat([hidden_excess[i], hidden_market[i]], -1) for i in range(2)]
|
||||
else:
|
||||
hidden = torch.cat([hidden_excess, hidden_market], -1)
|
||||
x = torch.zeros_like(x[:, 1, :])
|
||||
reconstructed_feature = []
|
||||
for i in range(T):
|
||||
x, hidden = self.dec(x, hidden)
|
||||
reconstructed_feature.append(x)
|
||||
reconstructed_feature = torch.stack(reconstructed_feature, 1)
|
||||
predicts["reconstructed_feature"] = reconstructed_feature
|
||||
return predicts
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=128, num_layers=1, dropout=0.5, base_model="GRU"):
|
||||
super().__init__()
|
||||
self.base_model = base_model
|
||||
if base_model == "GRU":
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
elif base_model == "LSTM":
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown base model name `%s`" % base_model)
|
||||
|
||||
self.fc = nn.Linear(hidden_size, d_feat)
|
||||
|
||||
def forward(self, x, hidden):
|
||||
x = x.unsqueeze(1)
|
||||
output, hidden = self.rnn(x, hidden)
|
||||
output = output.squeeze(1)
|
||||
pred = self.fc(output)
|
||||
return pred, hidden
|
||||
|
||||
|
||||
class RevGradFunc(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, alpha_):
|
||||
ctx.save_for_backward(input_, alpha_)
|
||||
output = input_
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output): # pragma: no cover
|
||||
grad_input = None
|
||||
_, alpha_ = ctx.saved_tensors
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = -grad_output * alpha_
|
||||
return grad_input, None
|
||||
|
||||
|
||||
class RevGrad(nn.Module):
    def __init__(self, gamma=0.1, gamma_clip=0.4, *args, **kwargs):
        """
        A gradient reversal layer.
        This layer has no parameters, and simply reverses the gradient
        in the backward pass.
        """
        super().__init__(*args, **kwargs)

        self.gamma = gamma
        self.gamma_clip = torch.tensor(float(gamma_clip), requires_grad=False)
        self._alpha = torch.tensor(0, requires_grad=False)
        self._p = 0

    def step_alpha(self):
        self._p += 1
        self._alpha = min(
            self.gamma_clip, torch.tensor(2 / (1 + math.exp(-self.gamma * self._p)) - 1, requires_grad=False)
        )

    def forward(self, input_):
        return RevGradFunc.apply(input_, self._alpha)
|
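`RevGrad.step_alpha` ramps the gradient-reversal strength from zero towards `gamma_clip`. The expression `2 / (1 + exp(-gamma * p)) - 1` equals `tanh(gamma * p / 2)`, so the ramp is smooth and then saturates at the clip value. A minimal sketch of the schedule with plain floats instead of tensors, where `alpha_schedule` is an illustrative name:

```python
import math


def alpha_schedule(gamma, gamma_clip, n_steps):
    """Return the alpha values after 1..n_steps calls to a step_alpha-like update."""
    return [min(gamma_clip, 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0) for p in range(1, n_steps + 1)]


# Same defaults as the constructor in the diff: gamma=0.1, gamma_clip=0.4.
alphas = alpha_schedule(gamma=0.1, gamma_clip=0.4, n_steps=20)
```

With these defaults the clip is reached at the ninth step. Since `bootstrap_fit` calls `step_alpha` once per epoch, the adversarial gradient is at full strength for most of a 200-epoch run.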
||||
@@ -73,7 +73,7 @@ class GATs(Model):
         base_model="GRU",
         model_path=None,
         optimizer="adam",
-        GPU=0,
+        GPU="0",
         n_jobs=10,
         seed=None,
         **kwargs
@@ -267,7 +267,7 @@ class DNNModelPytorch(Model):
             loss = torch.mul(sqr_loss, w).mean()
             return loss
         elif loss_type == "binary":
-            loss = nn.BCEWithLogitsLoss(weight=w)
+            loss = nn.BCELoss(weight=w)
             return loss(pred, target)
         else:
             raise NotImplementedError("loss {} is not supported!".format(loss_type))
@@ -334,8 +334,16 @@ class Net(nn.Module):
             dnn_layers.append(seq)
         drop_input = nn.Dropout(0.05)
         dnn_layers.append(drop_input)
-        fc = nn.Linear(hidden_units, output_dim)
-        dnn_layers.append(fc)
+        if loss == "mse":
+            fc = nn.Linear(hidden_units, output_dim)
+            dnn_layers.append(fc)
+
+        elif loss == "binary":
+            fc = nn.Linear(hidden_units, output_dim)
+            sigmoid = nn.Sigmoid()
+            dnn_layers.append(nn.Sequential(fc, sigmoid))
+        else:
+            raise NotImplementedError("loss {} is not supported!".format(loss))
         # optimizer
         self.dnn_layers = nn.ModuleList(dnn_layers)
         self._weight_init()
@@ -1,317 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from .tcn import TemporalConvNet
|
||||
|
||||
|
||||
class TCN(Model):
|
||||
"""TCN Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
n_chans: int
|
||||
number of channels
|
||||
metric: str
|
||||
the evaluation metric used for early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
the GPU ID used for training; a negative value selects the CPU
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
n_chans=128,
|
||||
kernel_size=5,
|
||||
num_layers=5,
|
||||
dropout=0.5,
|
||||
n_epochs=200,
|
||||
lr=0.0001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCN")
|
||||
self.logger.info("TCN pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.n_chans = n_chans
|
||||
self.kernel_size = kernel_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"TCN parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nn_chans : {}"
|
||||
"\nkernel_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
n_chans,
|
||||
kernel_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.tcn_model = TCNModel(
|
||||
num_input=self.d_feat,
|
||||
output_size=1,
|
||||
num_channels=[self.n_chans] * self.num_layers,
|
||||
kernel_size=self.kernel_size,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.tcn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.tcn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.tcn_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.tcn_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.tcn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
self.tcn_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.tcn_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.tcn_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.tcn_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.tcn_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
val_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.tcn_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.tcn_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.tcn_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.tcn_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class TCNModel(nn.Module):
|
||||
def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):
|
||||
super().__init__()
|
||||
self.num_input = num_input
|
||||
self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)
|
||||
self.linear = nn.Linear(num_channels[-1], output_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.reshape(x.shape[0], self.num_input, -1)
|
||||
output = self.tcn(x)
|
||||
output = self.linear(output[:, :, -1])
|
||||
return output.squeeze()
|
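The `fit` method in the file above keeps the best validation score, resets a patience counter on every improvement, and stops once `early_stop` consecutive epochs fail to improve. A minimal sketch of that control flow, replaying a fixed list of scores instead of training a model; `replay_early_stop` is a hypothetical name and the score values are made up for illustration.

```python
def replay_early_stop(val_scores, early_stop):
    """Replay per-epoch validation scores (higher is better).

    Returns (best_score, best_epoch, epochs_run).
    """
    best_score = float("-inf")
    best_epoch = 0
    stop_steps = 0
    epochs_run = 0
    for step, val_score in enumerate(val_scores):
        epochs_run = step + 1
        if val_score > best_score:
            # Improvement: remember it and reset the patience counter.
            # A real training loop would also snapshot the model weights here.
            best_score = val_score
            best_epoch = step
            stop_steps = 0
        else:
            stop_steps += 1
            if stop_steps >= early_stop:
                break
    return best_score, best_epoch, epochs_run


# Scores are negated losses, so values closer to zero are better.
scores = [-0.5, -0.3, -0.4, -0.35, -0.31, -0.2]
print(replay_early_stop(scores, early_stop=3))
print(replay_early_stop(scores, early_stop=20))
```

With a patience of 3 the loop stops before it reaches the last, better score, which is the usual trade-off between training time and the risk of stopping too early.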
||||
@@ -1,300 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from .tcn import TemporalConvNet
|
||||
|
||||
|
||||
class TCN(Model):
|
||||
"""TCN Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used for early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
the GPU ID used for training; a negative value selects the CPU
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
n_chans=128,
|
||||
kernel_size=5,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
lr=0.001,
|
||||
metric="",
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
optimizer="adam",
|
||||
n_jobs=10,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCN")
|
||||
self.logger.info("TCN pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.n_chans = n_chans
|
||||
self.kernel_size = kernel_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.lr = lr
|
||||
self.metric = metric
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
"TCN parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nn_chans : {}"
|
||||
"\nkernel_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nlr : {}"
|
||||
"\nmetric : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
n_chans,
|
||||
kernel_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
lr,
|
||||
metric,
|
||||
batch_size,
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.TCN_model = TCNModel(
|
||||
num_input=self.d_feat,
|
||||
output_size=1,
|
||||
num_channels=[self.n_chans] * self.num_layers,
|
||||
kernel_size=self.kernel_size,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.TCN_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.TCN_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.TCN_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
self.train_optimizer = optim.SGD(self.TCN_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.TCN_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def train_epoch(self, data_loader):
|
||||
|
||||
self.TCN_model.train()
|
||||
|
||||
for data in data_loader:
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.TCN_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.TCN_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_loader):
|
||||
|
||||
self.TCN_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
for data in data_loader:
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.TCN_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
# process nan brought by dataloader
|
||||
dl_train.config(fillna_type="ffill+bfill")
|
||||
# process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill")
|
||||
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(train_loader)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(train_loader)
|
||||
val_loss, val_score = self.test_epoch(valid_loader)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = step
|
||||
best_param = copy.deepcopy(self.TCN_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.TCN_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.TCN_model.eval()
|
||||
preds = []
|
||||
|
||||
for data in test_loader:
|
||||
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.TCN_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=dl_test.get_index())
|
||||
|
||||
|
||||
class TCNModel(nn.Module):
|
||||
def __init__(self, num_input, output_size, num_channels, kernel_size, dropout):
|
||||
super().__init__()
|
||||
self.num_input = num_input
|
||||
self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout)
|
||||
self.linear = nn.Linear(num_channels[-1], output_size)
|
||||
|
||||
def forward(self, x):
|
||||
output = self.tcn(x)
|
||||
output = self.linear(output[:, :, -1])
|
||||
return output.squeeze()
|
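Both TCN files above define `loss_fn` so that samples with a NaN label are masked out before the mean squared error is taken, and `metric_fn` returns the negated loss so that a larger score is better. A minimal pure-Python sketch of the masking step, under the assumption that predictions and labels are flat lists of floats; `masked_mse` is an illustrative name, not a qlib function.

```python
import math


def masked_mse(pred, label):
    """Mean squared error over the positions whose label is not NaN."""
    pairs = [(p, y) for p, y in zip(pred, label) if not math.isnan(y)]
    if not pairs:
        # Assumption: with no valid label the result is NaN,
        # mirroring a mean taken over an empty tensor.
        return float("nan")
    return sum((p - y) ** 2 for p, y in pairs) / len(pairs)


pred = [1.0, 2.0, 3.0]
label = [1.0, float("nan"), 5.0]

loss = masked_mse(pred, label)  # the NaN-labelled sample is ignored
score = -loss                   # negated loss, so larger is better
print(loss, score)
```

Without the mask, a single NaN label would turn the whole batch loss into NaN and propagate into the gradients.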
||||
@@ -1,77 +0,0 @@
|
||||
# MIT License
|
||||
# Copyright (c) 2018 CMU Locus Lab
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
class Chomp1d(nn.Module):
|
||||
def __init__(self, chomp_size):
|
||||
super(Chomp1d, self).__init__()
|
||||
self.chomp_size = chomp_size
|
||||
|
||||
def forward(self, x):
|
||||
return x[:, :, : -self.chomp_size].contiguous()
|
||||
|
||||
|
||||
class TemporalBlock(nn.Module):
|
||||
def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
|
||||
super(TemporalBlock, self).__init__()
|
||||
self.conv1 = weight_norm(
|
||||
nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
|
||||
)
|
||||
self.chomp1 = Chomp1d(padding)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
|
||||
self.conv2 = weight_norm(
|
||||
nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation)
|
||||
)
|
||||
self.chomp2 = Chomp1d(padding)
|
||||
self.relu2 = nn.ReLU()
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
self.conv1, self.chomp1, self.relu1, self.dropout1, self.conv2, self.chomp2, self.relu2, self.dropout2
|
||||
)
|
||||
self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
|
||||
self.relu = nn.ReLU()
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
self.conv1.weight.data.normal_(0, 0.01)
|
||||
self.conv2.weight.data.normal_(0, 0.01)
|
||||
if self.downsample is not None:
|
||||
self.downsample.weight.data.normal_(0, 0.01)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.net(x)
|
||||
res = x if self.downsample is None else self.downsample(x)
|
||||
return self.relu(out + res)
|
||||
|
||||
|
||||
class TemporalConvNet(nn.Module):
|
||||
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
|
||||
super(TemporalConvNet, self).__init__()
|
||||
layers = []
|
||||
num_levels = len(num_channels)
|
||||
for i in range(num_levels):
|
||||
dilation_size = 2 ** i
|
||||
in_channels = num_inputs if i == 0 else num_channels[i - 1]
|
||||
out_channels = num_channels[i]
|
||||
layers += [
|
||||
TemporalBlock(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=dilation_size,
|
||||
padding=(kernel_size - 1) * dilation_size,
|
||||
dropout=dropout,
|
||||
)
|
||||
]
|
||||
|
||||
self.network = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.network(x)
|
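`TemporalConvNet` above doubles the dilation at every level and pads each convolution by `(kernel_size - 1) * dilation_size`; `Chomp1d` then trims the same number of trailing steps, so each output step depends only on current and past inputs. A short sketch of the resulting padding schedule and receptive field, assuming two dilated convolutions per level as in `TemporalBlock`; the function names are illustrative.

```python
def tcn_padding(kernel_size: int, num_levels: int) -> list:
    """Padding per level, following padding = (kernel_size - 1) * 2**i."""
    return [(kernel_size - 1) * 2 ** i for i in range(num_levels)]


def tcn_receptive_field(kernel_size: int, num_levels: int) -> int:
    """Number of time steps (current plus past) visible to one output step.

    Each level holds two dilated convolutions, and each of them extends
    the receptive field by (kernel_size - 1) * dilation.
    """
    return 1 + 2 * sum(tcn_padding(kernel_size, num_levels))


# kernel_size=5 with 5 levels and with 2 levels, the two defaults seen above.
print(tcn_padding(5, 5), tcn_receptive_field(5, 5))
print(tcn_padding(5, 2), tcn_receptive_field(5, 2))
```

Comparing the receptive field with the length of the input window is a quick way to check whether additional levels can still add context or only add parameters.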
||||
@@ -1,14 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import yaml
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
import shutil
|
||||
from ...backtest.account import Account
|
||||
from ..backtest.account import Account
|
||||
from ..backtest.exchange import Exchange
|
||||
from .user import User
|
||||
from .utils import load_instance, save_instance
|
||||
from ...utils import init_instance_by_config
|
||||
from .utils import load_instance
|
||||
from ...utils import save_instance, init_instance_by_config
|
||||
|
||||
|
||||
class UserManager:
|
||||
|
||||
@@ -6,10 +6,10 @@ import pickle
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ...log import get_module_logger
|
||||
from ...utils import get_module_by_module_path, init_instance_by_config
|
||||
from ...utils import get_next_trading_date
|
||||
from ...backtest.exchange import Exchange
|
||||
from ..backtest.exchange import Exchange
|
||||
|
||||
log = get_module_logger("utils")
|
||||
|
||||
@@ -42,7 +42,7 @@ def save_instance(instance, file_path):
|
||||
"""
|
||||
file_path = pathlib.Path(file_path)
|
||||
with file_path.open("wb") as fr:
|
||||
pickle.dump(instance, fr, C.dump_protocol_version)
|
||||
pickle.dump(instance, fr)
|
||||
|
||||
|
||||
def create_user_folder(path):
|
||||
|
||||
@@ -57,7 +57,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
|
||||
).figure
|
||||
|
||||
t_df = t_df.loc[:, ["long-short", "long-average"]]
|
||||
_bin_size = float(((t_df.max() - t_df.min()) / 20).min())
|
||||
_bin_size = ((t_df.max() - t_df.min()) / 20).min()
|
||||
group_hist_figure = SubplotsGraph(
|
||||
t_df,
|
||||
kind_map=dict(kind="DistplotGraph", kwargs=dict(bin_size=_bin_size)),
|
||||
|
||||
@@ -171,55 +171,20 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.backtest import backtest, executor
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.evaluate import backtest
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
|
||||
# init qlib
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
# backtest parameters
|
||||
bparas = {}
|
||||
bparas['limit_threshold'] = 0.095
|
||||
bparas['account'] = 1000000000
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
FREQ = "day"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
# pred_score, pd.Series
|
||||
"signal": pred_score,
|
||||
}
|
||||
sparas = {}
|
||||
sparas['topk'] = 50
|
||||
sparas['n_drop'] = 230
|
||||
strategy = TopkDropoutStrategy(**sparas)
|
||||
|
||||
EXECUTOR_CONFIG = {
|
||||
"time_per_step": "day",
|
||||
"generate_portfolio_metrics": True,
|
||||
}
|
||||
|
||||
backtest_config = {
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"account": 100000000,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"exchange_kwargs": {
|
||||
"freq": FREQ,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
# strategy object
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
# executor object
|
||||
executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)
|
||||
# backtest
|
||||
portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)
|
||||
analysis_freq = "{0}{1}".format(*Freq.parse(FREQ))
|
||||
# backtest info
|
||||
report_normal_df, positions_normal = portfolio_metric_dict.get(analysis_freq)
|
||||
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
|
||||
|
||||
qcr.analysis_position.report_graph(report_normal_df)
|
||||
|
||||
|
||||
@@ -170,64 +170,32 @@ def risk_analysis_graph(
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.backtest import backtest, executor
|
||||
from qlib.contrib.evaluate import risk_analysis
|
||||
from qlib.contrib.evaluate import risk_analysis, backtest, long_short_backtest
|
||||
from qlib.contrib.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.report import analysis_position
|
||||
|
||||
# init qlib
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
# backtest parameters
|
||||
bparas = {}
|
||||
bparas['limit_threshold'] = 0.095
|
||||
bparas['account'] = 1000000000
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
FREQ = "day"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
# pred_score, pd.Series
|
||||
"signal": pred_score,
|
||||
}
|
||||
sparas = {}
|
||||
sparas['topk'] = 50
|
||||
sparas['n_drop'] = 230
|
||||
strategy = TopkDropoutStrategy(**sparas)
|
||||
|
||||
EXECUTOR_CONFIG = {
|
||||
"time_per_step": "day",
|
||||
"generate_portfolio_metrics": True,
|
||||
}
|
||||
report_normal_df, positions = backtest(pred_df, strategy, **bparas)
|
||||
# long_short_map = long_short_backtest(pred_df)
|
||||
# report_long_short_df = pd.DataFrame(long_short_map)
|
||||
|
||||
backtest_config = {
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"account": 100000000,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"exchange_kwargs": {
|
||||
"freq": FREQ,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
# strategy object
|
||||
strategy_obj = TopkDropoutStrategy(**STRATEGY_CONFIG)
|
||||
# executor object
|
||||
executor_obj = executor.SimulatorExecutor(**EXECUTOR_CONFIG)
|
||||
# backtest
|
||||
portfolio_metric_dict, indicator_dict = backtest(executor=executor_obj, strategy=strategy_obj, **backtest_config)
|
||||
analysis_freq = "{0}{1}".format(*Freq.parse(FREQ))
|
||||
# backtest info
|
||||
report_normal_df, positions_normal = portfolio_metric_dict.get(analysis_freq)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal_df["return"] - report_normal_df["bench"], freq=analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal_df["return"] - report_normal_df["bench"] - report_normal_df["cost"], freq=analysis_freq
|
||||
)
|
||||
# analysis['pred_long'] = risk_analysis(report_long_short_df['long'])
|
||||
# analysis['pred_short'] = risk_analysis(report_long_short_df['short'])
|
||||
# analysis['pred_long_short'] = risk_analysis(report_long_short_df['long_short'])
|
||||
analysis['excess_return_without_cost'] = risk_analysis(report_normal_df['return'] - report_normal_df['bench'])
|
||||
analysis['excess_return_with_cost'] = risk_analysis(report_normal_df['return'] - report_normal_df['bench'] - report_normal_df['cost'])
|
||||
analysis_df = pd.concat(analysis)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
analysis_position.risk_analysis_graph(analysis_df, report_normal_df)
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ class Tuner:
|
||||
space=self.space,
|
||||
algo=tpe.suggest,
|
||||
max_evals=self.max_evals,
|
||||
show_progressbar=False,
|
||||
)
|
||||
self.logger.info("Local best params: {} ".format(self.best_params))
|
||||
TimeInspector.log_cost_time(
|
||||
|
||||
@@ -8,8 +8,6 @@ from __future__ import print_function
|
||||
import abc
|
||||
import pandas as pd
|
||||
|
||||
from ..log import get_module_logger
|
||||
|
||||
|
||||
class Expression(abc.ABC):
|
||||
"""Expression base class"""
|
||||
@@ -152,15 +150,7 @@ class Expression(abc.ABC):
|
||||
return H["f"][args]
|
||||
if start_index is None or end_index is None or start_index > end_index:
|
||||
raise ValueError("Invalid index range: {} {}".format(start_index, end_index))
|
||||
try:
|
||||
series = self._load_internal(instrument, start_index, end_index, freq)
|
||||
except Exception as e:
|
||||
get_module_logger("data").debug(
|
||||
f"Loading data error: instrument={instrument}, expression={str(self)}, "
|
||||
f"start_index={start_index}, end_index={end_index}, freq={freq}. "
|
||||
f"error info: {str(e)}"
|
||||
)
|
||||
raise
|
||||
series = self._load_internal(instrument, start_index, end_index, freq)
|
||||
series.name = str(self)
|
||||
H["f"][args] = series
|
||||
return series
|
||||
|
||||
@@ -230,7 +230,7 @@ class CacheUtils:
|
||||
d["meta"]["visits"] = d["meta"]["visits"] + 1
|
||||
except KeyError:
|
||||
raise KeyError("Unknown meta keyword")
|
||||
pickle.dump(d, f, protocol=C.dump_protocol_version)
|
||||
pickle.dump(d, f)
|
||||
except Exception as e:
|
||||
get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}")
|
||||
|
||||
@@ -573,7 +573,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
meta_path = cache_path.with_suffix(".meta")
|
||||
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(meta, f, protocol=C.dump_protocol_version)
|
||||
pickle.dump(meta, f)
|
||||
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
df = expression_data.to_frame()
|
||||
|
||||
@@ -638,7 +638,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f, protocol=C.dump_protocol_version)
|
||||
pickle.dump(d, f)
|
||||
return 0
|
||||
|
||||
|
||||
@@ -927,7 +927,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
meta = {
|
||||
"info": {
|
||||
"instruments": instruments,
|
||||
"fields": list(cache_features.columns),
|
||||
"fields": cache_columns,
|
||||
"freq": freq,
|
||||
"last_update": str(_calendar[-1]), # The last_update to store the cache
|
||||
"inst_processors": inst_processors, # The instrument processors used to build the cache
|
||||
@@ -935,7 +935,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
"meta": {"last_visit": time.time(), "visits": 1},
|
||||
}
|
||||
with cache_path.with_suffix(".meta").open("wb") as f:
|
||||
pickle.dump(meta, f, protocol=C.dump_protocol_version)
|
||||
pickle.dump(meta, f)
|
||||
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
|
||||
# write index file
|
||||
im = DiskDatasetCache.IndexManager(cache_path)
|
||||
@@ -965,7 +965,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
fields = d["info"]["fields"]
|
||||
freq = d["info"]["freq"]
|
||||
last_update_time = d["info"]["last_update"]
|
||||
inst_processors = d["info"].get("inst_processors", [])
|
||||
inst_processors = d["info"]["inst_processors"]
|
||||
index_data = im.get_index()
|
||||
|
||||
self.logger.debug("Updating dataset: {}".format(d))
|
||||
@@ -1057,7 +1057,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
# update meta file
|
||||
d["info"]["last_update"] = str(new_calendar[-1])
|
||||
with meta_path.open("wb") as f:
|
||||
pickle.dump(d, f, protocol=C.dump_protocol_version)
|
||||
pickle.dump(d, f)
|
||||
return 0
|
||||
|
||||
|
||||
|
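Several hunks above differ only in whether `pickle.dump` receives an explicit `protocol` argument. A minimal sketch of what that argument changes, using only the standard library; `DUMP_PROTOCOL_VERSION` is a hypothetical stand-in for a configured value such as `C.dump_protocol_version`, and the `meta` dictionary is made-up sample data.

```python
import io
import pickle

# Hypothetical stand-in for a configured protocol such as C.dump_protocol_version.
DUMP_PROTOCOL_VERSION = 4


def dump_bytes(obj, protocol=None):
    """Pickle obj into bytes; protocol=None lets pickle use DEFAULT_PROTOCOL."""
    buf = io.BytesIO()
    pickle.dump(obj, buf, protocol=protocol)
    return buf.getvalue()


meta = {"info": {"freq": "day", "fields": ["$close"]}, "meta": {"visits": 1}}
pinned = dump_bytes(meta, protocol=DUMP_PROTOCOL_VERSION)
default = dump_bytes(meta)

# For protocol 2 and later, the stream starts with the PROTO opcode (0x80)
# followed by the protocol number, so the choice is visible in the bytes.
print(pinned[1], default[1], pickle.DEFAULT_PROTOCOL)

# Both streams load back to the same object on this interpreter.
print(pickle.loads(pinned) == pickle.loads(default) == meta)
```

Pinning the protocol matters mainly when cache files are written by one Python version and read by another, because the default protocol depends on the interpreter that writes the file.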
||||
@@ -1,103 +1,102 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import socketio
|
||||
|
||||
import qlib
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client:
|
||||
"""A client class
|
||||
|
||||
Provide the connection tool functions for ClientProvider.
|
||||
"""
|
||||
|
||||
def __init__(self, host, port):
|
||||
super(Client, self).__init__()
|
||||
self.sio = socketio.Client()
|
||||
self.server_host = host
|
||||
self.server_port = port
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
# bind connect/disconnect callbacks
|
||||
self.sio.on(
|
||||
"connect",
|
||||
lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)),
|
||||
)
|
||||
self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!"))
|
||||
|
||||
def connect_server(self):
|
||||
"""Connect to server."""
|
||||
try:
|
||||
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
|
||||
except socketio.exceptions.ConnectionError:
|
||||
self.logger.error("Cannot connect to server - check your network or server status")
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from server."""
|
||||
try:
|
||||
self.sio.eio.disconnect(True)
|
||||
except Exception as e:
|
||||
self.logger.error("Cannot disconnect from server : %s" % e)
|
||||
|
||||
def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None):
|
||||
"""Send a certain request to server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
request_type : str
|
||||
type of proposed request, 'calendar'/'instrument'/'feature'.
|
||||
request_content : dict
|
||||
records the information of the request.
|
||||
msg_proc_func : func
|
||||
the function to process the message when receiving response, should have arg `*args`.
|
||||
msg_queue: Queue
|
||||
The queue to pass the message after callback.
|
||||
"""
|
||||
head_info = {"version": qlib.__version__}
|
||||
|
||||
def request_callback(*args):
|
||||
"""callback_wrapper
|
||||
|
||||
:param *args: args[0] is the response content
|
||||
"""
|
||||
# args[0] is the response content
|
||||
self.logger.debug("receive data and enter queue")
|
||||
msg = dict(args[0])
|
||||
if msg["detailed_info"] is not None:
|
||||
if msg["status"] != 0:
|
||||
self.logger.error(msg["detailed_info"])
|
||||
else:
|
||||
self.logger.info(msg["detailed_info"])
|
||||
if msg["status"] != 0:
|
||||
ex = ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}")
|
||||
msg_queue.put(ex)
|
||||
else:
|
||||
if msg_proc_func is not None:
|
||||
try:
|
||||
ret = msg_proc_func(msg["result"])
|
||||
except Exception as e:
|
||||
self.logger.exception("Error when processing message.")
|
||||
ret = e
|
||||
else:
|
||||
ret = msg["result"]
|
||||
msg_queue.put(ret)
|
||||
self.disconnect()
|
||||
self.logger.debug("disconnected")
|
||||
|
||||
self.logger.debug("try connecting")
|
||||
self.connect_server()
|
||||
self.logger.debug("connected")
|
||||
# The pickle is for passing some parameters with special type(such as
|
||||
# pd.Timestamp)
|
||||
request_content = {"head": head_info, "body": pickle.dumps(request_content, protocol=C.dump_protocol_version)}
|
||||
self.sio.on(request_type + "_response", request_callback)
|
||||
self.logger.debug("try sending")
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
self.sio.wait()
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import socketio
|
||||
|
||||
import qlib
|
||||
from ..log import get_module_logger
|
||||
import pickle
|
||||
|
||||
|
||||
class Client:
|
||||
"""A client class
|
||||
|
||||
Provide the connection tool functions for ClientProvider.
|
||||
"""
|
||||
|
||||
def __init__(self, host, port):
|
||||
super(Client, self).__init__()
|
||||
self.sio = socketio.Client()
|
||||
self.server_host = host
|
||||
self.server_port = port
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
# bind connect/disconnect callbacks
|
||||
self.sio.on(
|
||||
"connect",
|
||||
lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)),
|
||||
)
|
||||
self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!"))
|
||||
|
||||
def connect_server(self):
|
||||
"""Connect to server."""
|
||||
try:
|
||||
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
|
||||
except socketio.exceptions.ConnectionError:
|
||||
self.logger.error("Cannot connect to server - check your network or server status")
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from server."""
|
||||
try:
|
||||
self.sio.eio.disconnect(True)
|
||||
except Exception as e:
|
||||
self.logger.error("Cannot disconnect from server : %s" % e)
|
||||
|
||||
def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None):
|
||||
"""Send a certain request to server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
request_type : str
|
||||
type of proposed request, 'calendar'/'instrument'/'feature'.
|
||||
request_content : dict
|
||||
records the information of the request.
|
||||
msg_proc_func : func
|
||||
the function to process the message when receiving response, should have arg `*args`.
|
||||
msg_queue: Queue
|
||||
The queue to pass the message after callback.
|
||||
"""
|
||||
head_info = {"version": qlib.__version__}
|
||||
|
||||
def request_callback(*args):
|
||||
"""callback_wrapper
|
||||
|
||||
:param *args: args[0] is the response content
|
||||
"""
|
||||
# args[0] is the response content
|
||||
self.logger.debug("receive data and enter queue")
|
||||
msg = dict(args[0])
|
||||
if msg["detailed_info"] is not None:
|
||||
if msg["status"] != 0:
|
||||
self.logger.error(msg["detailed_info"])
|
||||
else:
|
||||
self.logger.info(msg["detailed_info"])
|
||||
if msg["status"] != 0:
|
||||
ex = ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}")
|
||||
msg_queue.put(ex)
|
||||
else:
|
||||
if msg_proc_func is not None:
|
||||
try:
|
||||
ret = msg_proc_func(msg["result"])
|
||||
except Exception as e:
|
||||
self.logger.exception("Error when processing message.")
|
||||
ret = e
|
||||
else:
|
||||
ret = msg["result"]
|
||||
msg_queue.put(ret)
|
||||
self.disconnect()
|
||||
self.logger.debug("disconnected")
|
||||
|
||||
self.logger.debug("try connecting")
|
||||
self.connect_server()
|
||||
self.logger.debug("connected")
|
||||
# The pickle is for passing some parameters with special type(such as
|
||||
# pd.Timestamp)
|
||||
request_content = {"head": head_info, "body": pickle.dumps(request_content)}
|
||||
self.sio.on(request_type + "_response", request_callback)
|
||||
self.logger.debug("try sending")
|
||||
self.sio.emit(request_type + "_request", request_content)
|
||||
self.sio.wait()
|
||||
|
||||
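The `request_callback` closure in `send_request` hands either a result or an exception to the caller through `msg_queue`. A minimal sketch of that hand-off, with the socket layer left out, might look like the following. `make_request_callback` is a hypothetical helper introduced only for illustration; the message keys (`status`, `detailed_info`, `result`) are the ones used in the code above.

```python
import queue


def make_request_callback(msg_queue, msg_proc_func=None):
    """Build a callback that puts a result or an exception on `msg_queue`."""

    def request_callback(*args):
        # args[0] is the response content
        msg = dict(args[0])
        if msg["status"] != 0:
            # A non-zero status is reported to the consumer as an exception object.
            msg_queue.put(
                ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}")
            )
            return
        if msg_proc_func is None:
            ret = msg["result"]
        else:
            try:
                ret = msg_proc_func(msg["result"])
            except Exception as e:  # processing errors are also passed through the queue
                ret = e
        msg_queue.put(ret)

    return request_callback


q = queue.Queue()
cb = make_request_callback(q, msg_proc_func=lambda r: r * 2)
cb({"status": 0, "detailed_info": None, "result": 21})
ok = q.get_nowait()
cb({"status": 1, "detailed_info": "boom", "result": None})
err = q.get_nowait()
```

Because exceptions travel through the queue instead of being raised inside the callback, the consumer has to check the type of what it dequeues.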
@@ -27,6 +27,7 @@ from .inst_processor import InstProcessor
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import Freq
|
||||
from ..utils.resam import resam_calendar
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
@@ -37,7 +38,6 @@ from ..utils import (
|
||||
hash_args,
|
||||
normalize_cache_fields,
|
||||
code_to_fname,
|
||||
set_log_with_config,
|
||||
)
|
||||
from ..utils.paral import ParallelExt
|
||||
|
||||
@@ -57,13 +57,34 @@ class ProviderBackendMixin:
|
||||
backend = copy.deepcopy(backend)
|
||||
|
||||
# set default storage kwargs
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. qlib.init: provider_uri
|
||||
backend_kwargs = backend.setdefault("kwargs", {})
|
||||
provider_uri = backend_kwargs.get("provider_uri", None)
|
||||
provider_uri = C.dpm.provider_uri if provider_uri is None else C.dpm.format_provider_uri(provider_uri)
|
||||
backend_kwargs["provider_uri"] = provider_uri
|
||||
# default provider_uri map
|
||||
if "provider_uri" not in backend_kwargs:
|
||||
# if the user has no uri configured, use: uri = uri_map[freq]
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
provider_uri_map = backend_kwargs.setdefault("provider_uri_map", {})
|
||||
freq = kwargs.get("freq", "day")
|
||||
if freq not in provider_uri_map:
|
||||
# NOTE: uri
|
||||
# 1. If `freq` in C.dpm.provider_uri.keys(), uri = C.dpm.provider_uri[freq]
|
||||
# 2. If `freq` not in C.dpm.provider_uri.keys()
|
||||
# - Get the `min_freq` closest to `freq` from C.dpm.provider_uri.keys(), uri = C.dpm.provider_uri[min_freq]
|
||||
# NOTE: In Storage, only CalendarStorage is supported
|
||||
# 1. If `uri` does not exist
|
||||
# - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
|
||||
# - Read data from `min_uri` and resample to `freq`
|
||||
try:
|
||||
_uri = C.dpm.get_data_uri(freq)
|
||||
except KeyError:
|
||||
# provider_uri is dict and freq not in list(provider_uri.keys())
|
||||
# use the nearest freq greater than 0
|
||||
min_freq = Freq.get_recent_freq(freq, C.dpm.provider_uri.keys())
|
||||
_uri = C.dpm.get_data_uri(freq) if min_freq is None else C.dpm.get_data_uri(min_freq)
|
||||
provider_uri_map[freq] = _uri
|
||||
backend_kwargs["provider_uri"] = provider_uri_map[freq]
|
||||
backend.setdefault("kwargs", {}).update(**kwargs)
|
||||
return init_instance_by_config(backend)
|
||||
|
||||
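The comments in this hunk list a priority order for choosing `provider_uri`: the backend's own `provider_uri`, then its `provider_uri_map`, then the value given to `qlib.init`. The sketch below illustrates only that ordering. `resolve_provider_uri` and `init_uri_by_freq` are hypothetical names, the paths are made up, and the nearest-frequency fallback through `Freq.get_recent_freq` is deliberately left out.

```python
def resolve_provider_uri(backend_kwargs: dict, freq: str, init_uri_by_freq: dict) -> str:
    """Pick a provider URI following the three-level priority from the comments."""
    # 1. backend_config: backend_obj["kwargs"]["provider_uri"]
    if "provider_uri" in backend_kwargs:
        return backend_kwargs["provider_uri"]
    # 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
    uri_map = backend_kwargs.get("provider_uri_map", {})
    if freq in uri_map:
        return uri_map[freq]
    # 3. qlib.init: provider_uri (modelled here as a plain freq -> uri dict)
    return init_uri_by_freq[freq]


init_uris = {"day": "/data/init_day", "1min": "/data/init_1min"}  # illustrative paths

explicit = resolve_provider_uri({"provider_uri": "/data/explicit"}, "day", init_uris)
mapped = resolve_provider_uri({"provider_uri_map": {"day": "/data/mapped"}}, "day", init_uris)
fallback = resolve_provider_uri({}, "1min", init_uris)
```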
@@ -565,8 +586,6 @@ class DatasetProvider(abc.ABC):
|
||||
# NOTE: This place is compatible with windows, windows multi-process is spawn
|
||||
if not C.registered:
|
||||
C.set_conf_from_C(g_config)
|
||||
if C.logging_config:
|
||||
set_log_with_config(C.logging_config)
|
||||
C.register()
|
||||
|
||||
obj = dict()
|
||||
@@ -705,15 +724,7 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
end_time = pd.Timestamp(end_time)
|
||||
_, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False)
|
||||
lft_etd, rght_etd = expression.get_extended_window_size()
|
||||
try:
|
||||
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
|
||||
except Exception as e:
|
||||
get_module_logger("data").debug(
|
||||
f"Loading expression error: "
|
||||
f"instrument={instrument}, field=({field}), start_time={start_time}, end_time={end_time}, freq={freq}. "
|
||||
f"error info: {str(e)}"
|
||||
)
|
||||
raise
|
||||
series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq)
|
||||
# Ensure that each column type is consistent
|
||||
# FIXME:
|
||||
# 1) The stock data is currently float. If there are other types of data, this part needs to be re-implemented.
|
||||
|
||||
@@ -197,8 +197,6 @@ class Fillna(Processor):
|
||||
|
||||
class MinMaxNorm(Processor):
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
|
||||
# NOTE: correctly setting `fit_start_time` and `fit_end_time` is very important !!!
|
||||
# `fit_end_time` **must not** include any information from the test data!!!
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
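The NOTE in this hunk warns that `fit_end_time` must not reach into the test data. A small sketch of why: the statistics are computed on the fit window only and then applied to every date. This is plain Python over ISO date strings, not qlib's `MinMaxNorm`; `fit_min_max` and `apply_min_max` are hypothetical names.

```python
def fit_min_max(series: dict, fit_start_time: str, fit_end_time: str):
    """Compute min and max using only dates inside the fit window (inclusive)."""
    window = [v for d, v in series.items() if fit_start_time <= d <= fit_end_time]
    return min(window), max(window)


def apply_min_max(series: dict, lo: float, hi: float) -> dict:
    """Scale every value with statistics learned from the fit window."""
    return {d: (v - lo) / (hi - lo) for d, v in series.items()}


# ISO dates compare correctly as strings; the last date plays the role of test data.
data = {"2014-12-30": 1.0, "2014-12-31": 3.0, "2015-01-05": 5.0}
lo, hi = fit_min_max(data, "2008-01-01", "2014-12-31")
normalized = apply_min_max(data, lo, hi)
```

The test-period value is scaled to 2.0, outside the [0, 1] range, because it was never seen during fitting. If the fit window had included it, information from the test period would have leaked into the normalization.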
@@ -228,8 +226,6 @@ class ZScoreNorm(Processor):
|
||||
"""ZScore Normalization"""
|
||||
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None):
|
||||
# NOTE: correctly setting `fit_start_time` and `fit_end_time` is very important !!!
|
||||
# `fit_end_time` **must not** include any information from the test data!!!
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
@@ -267,8 +263,6 @@ class RobustZScoreNorm(Processor):
|
||||
"""
|
||||
|
||||
def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True):
|
||||
# NOTE: correctly setting `fit_start_time` and `fit_end_time` is very important !!!
|
||||
# `fit_end_time` **must not** include any information from the test data!!!
|
||||
self.fit_start_time = fit_start_time
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
@@ -308,13 +302,7 @@ class CSZScoreNorm(Processor):
|
||||
|
||||
|
||||
class CSRankNorm(Processor):
|
||||
"""
|
||||
Cross Sectional Rank Normalization.
|
||||
"Cross Sectional" is often used to describe data operations.
|
||||
The operations across different stocks are often called Cross Sectional Operation.
|
||||
|
||||
For example, CSRankNorm is an operation that groups the data by day and ranks `across` all the stocks in each day.
|
||||
"""
|
||||
"""Cross Sectional Rank Normalization"""
|
||||
|
||||
def __init__(self, fields_group=None):
|
||||
self.fields_group = fields_group
|
||||
|
||||
@@ -14,8 +14,6 @@ from typing import Union, List, Type
|
||||
from scipy.stats import percentileofscore
|
||||
|
||||
from .base import Expression, ExpressionOps, Feature
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
from ..utils import get_callable_kwargs
|
||||
|
||||
@@ -307,29 +305,7 @@ class NpPairOperator(PairOperator):
|
||||
series_right = self.feature_right.load(instrument, start_index, end_index, freq)
|
||||
else:
|
||||
series_right = self.feature_right
|
||||
check_length = isinstance(series_left, (np.ndarray, pd.Series)) and isinstance(
|
||||
series_right, (np.ndarray, pd.Series)
|
||||
)
|
||||
if check_length:
|
||||
warning_info = (
|
||||
f"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), "
|
||||
f"The length of series_left and series_right is different: ({len(series_left)}, {len(series_right)}), "
|
||||
f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data"
|
||||
)
|
||||
else:
|
||||
warning_info = (
|
||||
f"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), "
|
||||
f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data"
|
||||
)
|
||||
try:
|
||||
res = getattr(np, self.func)(series_left, series_right)
|
||||
except ValueError as e:
|
||||
get_module_logger("ops").debug(warning_info)
|
||||
raise ValueError(f"{str(e)}. \n\t{warning_info}")
|
||||
else:
|
||||
if check_length and len(series_left) != len(series_right):
|
||||
get_module_logger("ops").debug(warning_info)
|
||||
return res
|
||||
return getattr(np, self.func)(series_left, series_right)
|
||||
|
||||
|
||||
class Add(NpPairOperator):
|
||||
@@ -689,9 +665,6 @@ class If(ExpressionOps):
|
||||
|
||||
class Rolling(ExpressionOps):
|
||||
"""Rolling Operator
|
||||
The meaning of rolling and expanding is the same in pandas.
|
||||
When the window is set to 0, the behaviour of the operator should follow `expanding`
|
||||
Otherwise, it follows `rolling`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -1197,14 +1170,6 @@ class Delta(Rolling):
|
||||
# support pair-wise rolling like `Slope(A, B, N)`
|
||||
class Slope(Rolling):
|
||||
"""Rolling Slope
|
||||
This operator calculates the slope between `idx` and `feature`.
|
||||
(e.g. [<feature_t1>, <feature_t2>, <feature_t3>] and [1, 2, 3])
|
||||
|
||||
Usage Example:
|
||||
- "Slope($close, %d)/$close"
|
||||
|
||||
# TODO:
|
||||
# Some users may want pair-wise rolling like `Slope(A, B, N)`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
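The `Slope` docstring describes a regression of the feature values against the index `[1, 2, 3, ...]` inside each window. The sketch below shows that calculation as an ordinary least-squares slope in plain Python. `rolling_slope` is a hypothetical function, and returning `None` for incomplete windows is an assumption made here for simplicity; qlib's operator may treat partial windows differently.

```python
def rolling_slope(values, n):
    """Least-squares slope of each length-n window regressed on idx = 1..n."""
    xs = list(range(1, n + 1))
    mean_x = sum(xs) / n
    var_x = sum((x - mean_x) ** 2 for x in xs)
    out = []
    for end in range(len(values)):
        if end + 1 < n:
            out.append(None)  # assumption: incomplete windows yield None
            continue
        window = values[end + 1 - n : end + 1]
        mean_y = sum(window) / n
        cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, window))
        out.append(cov_xy / var_x)
    return out


slopes = rolling_slope([1.0, 3.0, 5.0, 7.0], 3)
```

For a series that rises by 2 per step, every complete window has slope 2.0, which is why an expression such as `Slope($close, %d)/$close` reads as a per-step trend relative to the price level.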
@@ -10,45 +10,23 @@ import pandas as pd
|
||||
|
||||
from qlib.utils.time import Freq
|
||||
from qlib.utils.resam import resam_calendar
|
||||
from qlib.config import C
|
||||
from qlib.data.cache import H
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
|
||||
from qlib.data.cache import H
|
||||
|
||||
logger = get_module_logger("file_storage")
|
||||
|
||||
|
||||
class FileStorageMixin:
|
||||
"""FileStorageMixin, applicable to FileXXXStorage
|
||||
Subclasses need to have provider_uri, freq, storage_name, file_name attributes
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def dpm(self):
|
||||
return C.DataPathManager(self.provider_uri, None)
|
||||
|
||||
@property
|
||||
def support_freq(self) -> List[str]:
|
||||
_v = "_support_freq"
|
||||
if hasattr(self, _v):
|
||||
return getattr(self, _v)
|
||||
if len(self.provider_uri) == 1 and C.DEFAULT_FREQ in self.provider_uri:
|
||||
freq_l = filter(
|
||||
lambda _freq: not _freq.endswith("_future"),
|
||||
map(lambda x: x.stem, self.dpm.get_data_uri(C.DEFAULT_FREQ).joinpath("calendars").glob("*.txt")),
|
||||
)
|
||||
else:
|
||||
freq_l = self.provider_uri.keys()
|
||||
freq_l = [Freq(freq) for freq in freq_l]
|
||||
setattr(self, _v, freq_l)
|
||||
return freq_l
|
||||
|
||||
@property
|
||||
def uri(self) -> Path:
|
||||
if self.freq not in self.support_freq:
|
||||
raise ValueError(f"{self.storage_name}: {self.provider_uri} does not contain data for {self.freq}")
|
||||
return self.dpm.get_data_uri(self.freq).joinpath(f"{self.storage_name}s", self.file_name)
|
||||
_provider_uri = self.kwargs.get("provider_uri", None)
|
||||
if _provider_uri is None:
|
||||
raise ValueError(
|
||||
f"The `provider_uri` parameter is not found in {self.__class__.__name__}, "
|
||||
f'please specify `provider_uri` in the "provider\'s backend"'
|
||||
)
|
||||
return Path(_provider_uri).expanduser().joinpath(f"{self.storage_name}s", self.file_name)
|
||||
|
||||
def check(self):
|
||||
"""check self.uri
|
||||
@@ -62,32 +40,10 @@ class FileStorageMixin:
|
||||
|
||||
|
||||
class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
def __init__(self, freq: str, future: bool, provider_uri: dict, **kwargs):
|
||||
def __init__(self, freq: str, future: bool, **kwargs):
|
||||
super(FileCalendarStorage, self).__init__(freq, future, **kwargs)
|
||||
self.future = future
|
||||
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self.enable_read_cache = True # TODO: make it configurable
|
||||
|
||||
@property
|
||||
def file_name(self) -> str:
|
||||
return f"{self._freq_file}_future.txt" if self.future else f"{self._freq_file}.txt".lower()
|
||||
|
||||
@property
|
||||
def _freq_file(self) -> str:
|
||||
"""the freq to read from file"""
|
||||
if not hasattr(self, "_freq_file_cache"):
|
||||
freq = Freq(self.freq)
|
||||
if freq not in self.support_freq:
|
||||
# NOTE: uri
|
||||
# 1. If `uri` does not exist
|
||||
# - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
|
||||
# - Read data from `min_uri` and resample to `freq`
|
||||
|
||||
freq = Freq.get_recent_freq(freq, self.support_freq)
|
||||
if freq is None:
|
||||
raise ValueError(f"can't find a freq from {self.support_freq} that can resample to {self.freq}!")
|
||||
self._freq_file_cache = freq
|
||||
return self._freq_file_cache
|
||||
self.file_name = f"{freq}_future.txt" if future else f"{freq}.txt".lower()
|
||||
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
|
||||
if not self.uri.exists():
|
||||
@@ -102,23 +58,29 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
||||
with self.uri.open(mode=mode) as fp:
|
||||
np.savetxt(fp, values, fmt="%s", encoding="utf-8")
|
||||
|
||||
@property
|
||||
def uri(self) -> Path:
|
||||
return self.dpm.get_data_uri(self._freq_file).joinpath(f"{self.storage_name}s", self.file_name)
|
||||
|
||||
@property
|
||||
def data(self) -> List[CalVT]:
|
||||
self.check()
|
||||
# If cache is enabled, then return cache directly
|
||||
if self.enable_read_cache:
|
||||
key = "orig_file" + str(self.uri)
|
||||
if not key in H["c"]:
|
||||
H["c"][key] = self._read_calendar()
|
||||
_calendar = H["c"][key]
|
||||
else:
|
||||
# NOTE: uri
|
||||
# 1. If `uri` does not exist
|
||||
# - Get the `min_uri` of the closest `freq` under the same "directory" as the `uri`
|
||||
# - Read data from `min_uri` and resample to `freq`
|
||||
try:
|
||||
self.check()
|
||||
_calendar = self._read_calendar()
|
||||
if Freq(self._freq_file) != Freq(self.freq):
|
||||
_calendar = resam_calendar(np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq)
|
||||
except ValueError:
|
||||
freq_list = self._get_storage_freq()
|
||||
_freq = Freq.get_recent_freq(self.freq, freq_list)
|
||||
if _freq is None:
|
||||
raise ValueError(f"can't find a freq from {freq_list} that can resample to {self.freq}!")
|
||||
self.file_name = f"{_freq}_future.txt" if self.future else f"{_freq}.txt".lower()
|
||||
# The cache is useful for the following cases
|
||||
# - multiple frequencies are sampled from the same calendar
|
||||
cache_key = self.uri
|
||||
if cache_key not in H["c"]:
|
||||
H["c"][cache_key] = self._read_calendar()
|
||||
_calendar = H["c"][cache_key]
|
||||
_calendar = resam_calendar(np.array(list(map(pd.Timestamp, _calendar))), _freq, self.freq)
|
||||
|
||||
return _calendar
|
||||
|
||||
def _get_storage_freq(self) -> List[str]:
|
||||
@@ -173,9 +135,8 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
|
||||
INSTRUMENT_END_FIELD = "end_datetime"
|
||||
SYMBOL_FIELD_NAME = "instrument"
|
||||
|
||||
def __init__(self, market: str, freq: str, provider_uri: dict, **kwargs):
|
||||
super(FileInstrumentStorage, self).__init__(market, freq, **kwargs)
|
||||
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
|
||||
def __init__(self, market: str, **kwargs):
|
||||
super(FileInstrumentStorage, self).__init__(market, **kwargs)
|
||||
self.file_name = f"{market.lower()}.txt"
|
||||
|
||||
def _read_instrument(self) -> Dict[InstKT, InstVT]:
|
||||
@@ -262,9 +223,8 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage):
|
||||
|
||||
|
||||
class FileFeatureStorage(FileStorageMixin, FeatureStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict, **kwargs):
|
||||
def __init__(self, instrument: str, field: str, freq: str, **kwargs):
|
||||
super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs)
|
||||
self.provider_uri = C.DataPathManager.format_provider_uri(provider_uri)
|
||||
self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin"
|
||||
|
||||
def clear(self):
|
||||
|
||||
@@ -195,9 +195,8 @@ class CalendarStorage(BaseStorage):
|
||||
|
||||
|
||||
class InstrumentStorage(BaseStorage):
|
||||
def __init__(self, market: str, freq: str, **kwargs):
|
||||
def __init__(self, market: str, **kwargs):
|
||||
self.market = market
|
||||
self.freq = freq
|
||||
self.kwargs = kwargs
|
||||
|
||||
@property
|
||||
|
||||
@@ -16,7 +16,6 @@ import time
|
||||
import re
|
||||
from typing import Callable, List
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
@@ -26,48 +25,6 @@ from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
# from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
|
||||
def _log_task_info(task_config: dict):
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
R.set_tags(**{"hostname": socket.gethostname()})
|
||||
|
||||
|
||||
def _exe_task(task_config: dict):
|
||||
rec = R.get_recorder()
|
||||
# model & dataset initiation
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
# FIXME: resume reweighter after merging data selection
|
||||
# reweighter: Reweighter = task_config.get("reweighter", None)
|
||||
# model training
|
||||
# auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
# this dataset is saved for online inference. So the concrete data should not be dumped
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
# fill placeholder
|
||||
placehorder_value = {"<MODEL>": model, "<DATASET>": dataset}
|
||||
task_config = fill_placeholder(task_config, placehorder_value)
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
# Some recorders require the parameters `model` and `dataset`.
|
||||
# try to automatically pass in them to the initialization function
|
||||
# to make defining the tasking easier
|
||||
r = init_instance_by_config(
|
||||
record,
|
||||
recorder=rec,
|
||||
default_module="qlib.workflow.record_temp",
|
||||
try_kwargs={"model": model, "dataset": dataset},
|
||||
)
|
||||
r.generate()
|
||||
|
||||
|
||||
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
|
||||
"""
|
||||
@@ -82,65 +39,17 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
|
||||
Recorder: the model recorder
|
||||
"""
|
||||
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
|
||||
_log_task_info(task_config)
|
||||
return R.get_recorder()
|
||||
|
||||
|
||||
def get_item_from_obj(config: dict, name_path: str) -> object:
|
||||
"""
|
||||
Follow the name_path to get values from config
|
||||
For example:
|
||||
If we follow the example in the Parameters section,
|
||||
Timestamp('2008-01-02 00:00:00') will be returned
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : dict
|
||||
e.g.
|
||||
{'dataset': {'class': 'DatasetH',
|
||||
'kwargs': {'handler': {'class': 'Alpha158',
|
||||
'kwargs': {'end_time': '2020-08-01',
|
||||
'fit_end_time': '<dataset.kwargs.segments.train.1>',
|
||||
'fit_start_time': '<dataset.kwargs.segments.train.0>',
|
||||
'instruments': 'csi100',
|
||||
'start_time': '2008-01-01'},
|
||||
'module_path': 'qlib.contrib.data.handler'},
|
||||
'segments': {'test': (Timestamp('2017-01-03 00:00:00'),
|
||||
Timestamp('2019-04-08 00:00:00')),
|
||||
'train': (Timestamp('2008-01-02 00:00:00'),
|
||||
Timestamp('2014-12-31 00:00:00')),
|
||||
'valid': (Timestamp('2015-01-05 00:00:00'),
|
||||
Timestamp('2016-12-30 00:00:00'))}}
|
||||
}}
|
||||
name_path : str
|
||||
e.g.
|
||||
"dataset.kwargs.segments.train.1"
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
the retrieved object
|
||||
"""
|
||||
cur_cfg = config
|
||||
for k in name_path.split("."):
|
||||
if isinstance(cur_cfg, dict):
|
||||
cur_cfg = cur_cfg[k]
|
||||
elif k.isdigit():
|
||||
cur_cfg = cur_cfg[int(k)]
|
||||
else:
|
||||
raise ValueError(f"Error when getting {k} from cur_cfg")
|
||||
return cur_cfg
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
R.set_tags(**{"hostname": socket.gethostname()})
|
||||
recorder: Recorder = R.get_recorder()
|
||||
return recorder
|
||||
|
||||
|
||||
def fill_placeholder(config: dict, config_extend: dict):
|
||||
"""
|
||||
Detect placeholder in config and fill them with config_extend.
|
||||
The item of dict must be single item(int, str, etc), dict and list. Tuples are not supported.
|
||||
There are two types of variables:
|
||||
- user-defined variables :
|
||||
e.g. when config_extend is `{"<MODEL>": model, "<DATASET>": dataset}`, "<MODEL>" and "<DATASET>" in `config` will be replaced with `model` `dataset`
|
||||
- variables extracted from `config` :
|
||||
e.g. the variables like "<dataset.kwargs.segments.train.0>" will be replaced with the values from `config`
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -173,13 +82,8 @@ def fill_placeholder(config: dict, config_extend: dict):
|
||||
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
|
||||
item_queue.append(now_item[key])
|
||||
tail += 1
|
||||
elif isinstance(now_item[key], str):
|
||||
if now_item[key] in config_extend.keys():
|
||||
now_item[key] = config_extend[now_item[key]]
|
||||
else:
|
||||
m = re.match(r"<(?P<name_path>[^<>]+)>", now_item[key])
|
||||
if m is not None:
|
||||
now_item[key] = get_item_from_obj(config, m.groupdict()["name_path"])
|
||||
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
|
||||
now_item[key] = config_extend[now_item[key]]
|
||||
return config
|
||||
|
||||
|
||||
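The two placeholder kinds described in the `fill_placeholder` docstring can be illustrated with a trimmed-down config. The `get_item_from_obj` body below is copied from the hunk above; the surrounding config and the name `resolved` are illustrative, and plain date strings replace the `Timestamp` objects so that the sketch has no third-party dependencies.

```python
import re


def get_item_from_obj(config: dict, name_path: str) -> object:
    """Follow a dotted name_path through nested dicts and sequences."""
    cur_cfg = config
    for k in name_path.split("."):
        if isinstance(cur_cfg, dict):
            cur_cfg = cur_cfg[k]
        elif k.isdigit():
            cur_cfg = cur_cfg[int(k)]
        else:
            raise ValueError(f"Error when getting {k} from cur_cfg")
    return cur_cfg


config = {
    "dataset": {
        "kwargs": {
            "handler": {"kwargs": {"fit_end_time": "<dataset.kwargs.segments.train.1>"}},
            "segments": {"train": ("2008-01-02", "2014-12-31")},
        }
    }
}

# A value such as "<dataset.kwargs.segments.train.1>" is matched by the same regex
# used in fill_placeholder and then looked up in the config itself.
raw = config["dataset"]["kwargs"]["handler"]["kwargs"]["fit_end_time"]
m = re.match(r"<(?P<name_path>[^<>]+)>", raw)
resolved = get_item_from_obj(config, m.groupdict()["name_path"])
```

User-defined variables such as `"<MODEL>"` are looked up in `config_extend` first; only strings not found there fall through to this path-based lookup.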
@@ -196,11 +100,38 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
|
||||
task_config = R.load_object("task")
|
||||
_exe_task(task_config)
|
||||
# model & dataset initiation
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
# model training
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
# this dataset is saved for online inference. So the concrete data should not be dumped
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
# fill placeholder
|
||||
placehorder_value = {"<MODEL>": model, "<DATASET>": dataset}
|
||||
task_config = fill_placeholder(task_config, placehorder_value)
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
if isinstance(records, dict): # uniform the data format to list
|
||||
records = [records]
|
||||
|
||||
for record in records:
|
||||
# Some recorders require the parameters `model` and `dataset`.
|
||||
# try to automatically pass in them to the initialization function
|
||||
# to make defining the tasking easier
|
||||
r = init_instance_by_config(
|
||||
record,
|
||||
recorder=rec,
|
||||
default_module="qlib.workflow.record_temp",
|
||||
try_kwargs={"model": model, "dataset": dataset},
|
||||
)
|
||||
r.generate()
|
||||
return rec
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
Task based training, will be divided into two steps.
|
||||
|
||||
@@ -210,17 +141,14 @@ def task_train(task_config: dict, experiment_name: str, recorder_name: str = Non
|
||||
The config of a task.
|
||||
experiment_name: str
|
||||
The name of experiment
|
||||
recorder_name: str
|
||||
The name of recorder
|
||||
|
||||
Returns
|
||||
----------
|
||||
Recorder: The instance of the recorder
|
||||
"""
|
||||
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
|
||||
_log_task_info(task_config)
|
||||
_exe_task(task_config)
|
||||
return R.get_recorder()
|
||||
recorder = begin_task_train(task_config, experiment_name)
|
||||
recorder = end_task_train(recorder, experiment_name)
|
||||
return recorder
|
||||
|
||||
|
||||
class Trainer:
|
||||
@@ -276,30 +204,6 @@ class Trainer:
|
||||
def __call__(self, *args, **kwargs) -> list:
|
||||
return self.end_train(self.train(*args, **kwargs))
|
||||
|
||||
def has_worker(self) -> bool:
|
||||
"""
|
||||
Some trainers have a backend worker to support parallel training
|
||||
This method can tell if the worker is enabled.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
if the worker is enabled
|
||||
|
||||
"""
|
||||
return False
|
||||
|
||||
def worker(self):
|
||||
"""
|
||||
start the worker
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the worker is not supported
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `worker` method")
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""
|
||||
@@ -348,7 +252,7 @@ class TrainerR(Trainer):
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
recs = []
|
||||
for task in tqdm(tasks):
|
||||
for task in tasks:
|
||||
rec = train_func(task, experiment_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
@@ -553,9 +457,6 @@ class TrainerRM(Trainer):
|
||||
task_pool = experiment_name
|
||||
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
|
||||
|
||||
def has_worker(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -678,6 +579,3 @@ class DelayTrainerRM(TrainerRM):
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
)
|
||||
|
||||
def has_worker(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -50,8 +50,8 @@ RECORD_CONFIG = [
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="<dataset.kwargs.segments.train.0>",
|
||||
fit_end_time="<dataset.kwargs.segments.train.1>",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
|
||||
@@ -578,7 +578,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
|
||||
"""get the trading date shifted by `shift` relative to cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -587,25 +587,14 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="
|
||||
current date
|
||||
shift : int
|
||||
clip_shift: bool
|
||||
align : Optional[str]
|
||||
When align is None, this function will raise ValueError if `trading_date` is not a trading date
|
||||
when align is "left"/"right", it will try to align to left/right nearest trading date before shifting when `trading_date` is not a trading date
|
||||
|
||||
"""
|
||||
from qlib.data import D
|
||||
|
||||
cal = D.calendar(future=future, freq=freq)
|
||||
trading_date = pd.to_datetime(trading_date)
|
||||
if align is None:
|
||||
if trading_date not in list(cal):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
elif align == "left":
|
||||
_index = bisect.bisect_right(cal, trading_date) - 1
|
||||
elif align == "right":
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
else:
|
||||
raise ValueError(f"align with value `{align}` is not supported")
|
||||
if pd.to_datetime(trading_date) not in list(cal):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
shift_index = _index + shift
|
||||
if shift_index < 0 or shift_index >= len(cal):
|
||||
if clip_shift:
|
||||
|
||||
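The `align` behaviour described in the docstring above can be sketched on a plain list of timestamps. This is a minimal illustration, not the library implementation: `shift_trading_date` and the explicit `cal` argument are hypothetical names (the real function reads the calendar from `D.calendar`), and out-of-range indices are always clipped here, which assumes `clip_shift=True`.

```python
import bisect
from typing import List, Optional

import pandas as pd


def shift_trading_date(cal: List[pd.Timestamp], trading_date, shift: int, align: Optional[str] = None) -> pd.Timestamp:
    """Hypothetical stand-in for get_date_by_shift, working on an explicit calendar."""
    trading_date = pd.Timestamp(trading_date)
    if align is None:
        # strict mode: the date must already be in the calendar
        _index = bisect.bisect_left(cal, trading_date)
        if _index == len(cal) or cal[_index] != trading_date:
            raise ValueError(f"{trading_date} is not trading day!")
    elif align == "left":
        # nearest trading date at or before trading_date
        _index = bisect.bisect_right(cal, trading_date) - 1
    elif align == "right":
        # nearest trading date at or after trading_date
        _index = bisect.bisect_left(cal, trading_date)
    else:
        raise ValueError(f"align with value `{align}` is not supported")
    # assumption: always clip to the calendar bounds (clip_shift=True)
    shift_index = min(max(_index + shift, 0), len(cal) - 1)
    return cal[shift_index]


# hypothetical calendar with a gap between 2021-01-05 and 2021-01-08
cal = [pd.Timestamp("2021-01-04"), pd.Timestamp("2021-01-05"), pd.Timestamp("2021-01-08")]
left = shift_trading_date(cal, "2021-01-06", 0, align="left")
right = shift_trading_date(cal, "2021-01-06", 0, align="right")
```

With `align=None` the same non-trading date raises `ValueError`, which matches the behaviour of the version without the `align` parameter.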
@@ -106,7 +106,7 @@ class FileManager(ObjManager):
|
||||
|
||||
def save_obj(self, obj, name):
|
||||
with (self.path / name).open("wb") as f:
|
||||
pickle.dump(obj, f, protocol=C.dump_protocol_version)
|
||||
pickle.dump(obj, f)
|
||||
|
||||
def save_objs(self, obj_name_l):
|
||||
for obj, name in obj_name_l:
|
||||
|
||||
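The hunk above switches between `pickle.dump(obj, f, protocol=...)` and the default-protocol call. A minimal sketch of what pinning the protocol does, assuming a hypothetical constant `DUMP_PROTOCOL_VERSION` in place of `C.dump_protocol_version` (its actual value is not shown in this diff):

```python
import io
import pickle

# hypothetical stand-in for C.dump_protocol_version
DUMP_PROTOCOL_VERSION = 4

obj = {"name": "pred", "values": [1, 2, 3]}

buf = io.BytesIO()
# pinning the protocol keeps the byte format stable across interpreters
# whose default protocol differs
pickle.dump(obj, buf, protocol=DUMP_PROTOCOL_VERSION)
payload = buf.getvalue()

# pickles written with protocol 2 or later start with the PROTO opcode (0x80)
# followed by the protocol number
header = payload[:2]
restored = pickle.loads(payload)
```

Without the `protocol` argument, `pickle` uses the interpreter's default protocol, so files written by a newer Python may not load on an older one.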
@@ -8,7 +8,7 @@ from . import lazy_sort_index
|
||||
from .time import Freq, cal_sam_minute
|
||||
|
||||
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: Union[str, Freq], freq_sam: Union[str, Freq]) -> np.ndarray:
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
"""
|
||||
Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam
|
||||
Assumption:
|
||||
@@ -28,36 +28,36 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: Union[str, Freq], freq_sa
|
||||
np.ndarray
|
||||
The calendar with frequency freq_sam
|
||||
"""
|
||||
freq_raw = Freq(freq_raw)
|
||||
freq_sam = Freq(freq_sam)
|
||||
raw_count, freq_raw = Freq.parse(freq_raw)
|
||||
sam_count, freq_sam = Freq.parse(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
|
||||
# if freq_sam is xminute, divide each trading day into several bars evenly
|
||||
if freq_sam.base == Freq.NORM_FREQ_MINUTE:
|
||||
if freq_raw.base != Freq.NORM_FREQ_MINUTE:
|
||||
if freq_sam == Freq.NORM_FREQ_MINUTE:
|
||||
if freq_raw != Freq.NORM_FREQ_MINUTE:
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
if freq_raw.count > freq_sam.count:
|
||||
if raw_count > sam_count:
|
||||
raise ValueError("raw freq must be higher than sampling freq")
|
||||
_calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, freq_sam.count), calendar_raw)))
|
||||
_calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, sam_count), calendar_raw)))
|
||||
return _calendar_minute
|
||||
|
||||
# else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam.base == Freq.NORM_FREQ_DAY:
|
||||
return _calendar_day[:: freq_sam.count]
|
||||
if freq_sam == Freq.NORM_FREQ_DAY:
|
||||
return _calendar_day[::sam_count]
|
||||
|
||||
elif freq_sam.base == Freq.NORM_FREQ_WEEK:
|
||||
elif freq_sam == Freq.NORM_FREQ_WEEK:
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
|
||||
return _calendar_week[:: freq_sam.count]
|
||||
return _calendar_week[::sam_count]
|
||||
|
||||
elif freq_sam.base == Freq.NORM_FREQ_MONTH:
|
||||
elif freq_sam == Freq.NORM_FREQ_MONTH:
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
|
||||
return _calendar_month[:: freq_sam.count]
|
||||
return _calendar_month[::sam_count]
|
||||
else:
|
||||
raise ValueError("sampling freq must be xmin, xd, xw, xm")
|
||||
|
||||
|
||||
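The week branch of `resam_calendar` keeps the first trading day of each week by looking for a drop in the weekday number. A small sketch of that idea, assuming a `pd.DatetimeIndex` as input instead of the raw `np.ndarray` used above; `first_day_of_each_week` is a hypothetical helper name:

```python
import numpy as np
import pandas as pd


def first_day_of_each_week(calendar_day: pd.DatetimeIndex, count: int = 1) -> pd.DatetimeIndex:
    """Keep the first trading day of every `count`-th week."""
    day_in_week = calendar_day.dayofweek.to_numpy()
    # a negative difference means the weekday number dropped, i.e. a new week began;
    # to_begin=-1 marks the very first day as a week start
    is_week_start = np.ediff1d(day_in_week, to_begin=-1) < 0
    return calendar_day[is_week_start][::count]


# two full business weeks, Monday 2021-01-04 to Friday 2021-01-15
days = pd.bdate_range("2021-01-04", "2021-01-15")
weekly = first_day_of_each_week(days)
```

One limitation follows directly from the comparison: if two consecutive calendar entries fall in different weeks but the weekday number does not decrease (for example a Wednesday followed by the next week's Thursday), the second week is not detected as a new bar.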
@@ -5,7 +5,6 @@ import pickle
|
||||
import dill
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from ..config import C
|
||||
|
||||
|
||||
class Serializable:
|
||||
@@ -86,8 +85,7 @@ class Serializable:
|
||||
"""
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
# pickle interface like backend; such as dill
|
||||
self.get_backend().dump(self, f, protocol=C.dump_protocol_version)
|
||||
self.get_backend().dump(self, f)
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
@@ -118,7 +116,6 @@ class Serializable:
|
||||
Returns:
|
||||
module: pickle or dill module based on pickle_backend
|
||||
"""
|
||||
# NOTE: pickle interface like backend; such as dill
|
||||
if cls.pickle_backend == "pickle":
|
||||
return pickle
|
||||
elif cls.pickle_backend == "dill":
|
||||
@@ -143,4 +140,4 @@ class Serializable:
|
||||
obj.to_pickle(path)
|
||||
else:
|
||||
with path.open("wb") as f:
|
||||
pickle.dump(obj, f, protocol=C.dump_protocol_version)
|
||||
pickle.dump(obj, f)
|
||||
|
||||
@@ -5,7 +5,7 @@ Time related utils are compiled in this script
|
||||
"""
|
||||
import bisect
|
||||
from datetime import datetime, time, date
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Tuple, Union
|
||||
import functools
|
||||
import re
|
||||
|
||||
@@ -69,29 +69,13 @@ class Freq:
|
||||
NORM_FREQ_MONTH = "month"
|
||||
NORM_FREQ_WEEK = "week"
|
||||
NORM_FREQ_DAY = "day"
|
||||
NORM_FREQ_MINUTE = "min" # using min instead of minute for align with Qlib's data filename
|
||||
NORM_FREQ_MINUTE = "minute"
|
||||
SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE, NORM_FREQ_DAY] # FIXME: this list should from data
|
||||
|
||||
MIN_CAL = get_min_cal()
|
||||
|
||||
def __init__(self, freq: Union[str, "Freq"]) -> None:
|
||||
if isinstance(freq, str):
|
||||
self.count, self.base = self.parse(freq)
|
||||
elif isinstance(freq, Freq):
|
||||
self.count, self.base = freq.count, freq.base
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def __eq__(self, freq):
|
||||
freq = Freq(freq)
|
||||
return freq.count == self.count and freq.base == self.base
|
||||
|
||||
def __str__(self):
|
||||
# trying to align to the filename of Qlib: day, 30min, 5min, 1min...
|
||||
return f"{self.count if self.count != 1 or self.base != 'day' else ''}{self.base}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}({str(self)})"
|
||||
def __init__(self, freq: str) -> None:
|
||||
self.count, self.base = self.parse(freq)
|
||||
|
||||
@staticmethod
|
||||
def parse(freq: str) -> Tuple[int, str]:
|
||||
@@ -175,14 +159,14 @@ class Freq:
|
||||
Freq.NORM_FREQ_WEEK: 7 * 60 * 24,
|
||||
Freq.NORM_FREQ_MONTH: 30 * 7 * 60 * 24,
|
||||
}
|
||||
left_freq = Freq(left_frq)
|
||||
left_minutes = left_freq.count * minutes_map[left_freq.base]
|
||||
right_freq = Freq(right_freq)
|
||||
right_minutes = right_freq.count * minutes_map[right_freq.base]
|
||||
left_freq = Freq.parse(left_frq)
|
||||
left_minutes = left_freq[0] * minutes_map[left_freq[1]]
|
||||
right_freq = Freq.parse(right_freq)
|
||||
right_minutes = right_freq[0] * minutes_map[right_freq[1]]
|
||||
return left_minutes - right_minutes
|
||||
|
||||
@staticmethod
|
||||
def get_recent_freq(base_freq: Union[str, "Freq"], freq_list: List[Union[str, "Freq"]]) -> Optional["Freq"]:
|
||||
def get_recent_freq(base_freq: str, freq_list: List[str]) -> str:
|
||||
"""Get the closest freq to base_freq from freq_list
|
||||
|
||||
Parameters
|
||||
@@ -192,22 +176,17 @@ class Freq:
|
||||
|
||||
Returns
|
||||
-------
|
||||
if the recent frequency is found
|
||||
Freq
|
||||
else:
|
||||
None
|
||||
|
||||
"""
|
||||
base_freq = Freq(base_freq)
|
||||
# use the nearest freq greater than 0
|
||||
_freq_minutes = []
|
||||
min_freq = None
|
||||
for _freq in freq_list:
|
||||
freq = Freq(_freq)
|
||||
_min_delta = Freq.get_min_delta(base_freq, _freq)
|
||||
if _min_delta < 0:
|
||||
continue
|
||||
if min_freq is None:
|
||||
min_freq = (_min_delta, str(_freq))
|
||||
min_freq = (_min_delta, _freq)
|
||||
continue
|
||||
min_freq = min_freq if min_freq[0] <= _min_delta else (_min_delta, _freq)
|
||||
return min_freq[1] if min_freq else None
|
||||
|
||||
@@ -90,24 +90,14 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
SZ300676 -0.001321
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, from_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
|
||||
"""
|
||||
Init PredUpdater.
|
||||
|
||||
Expected behavior in following cases:
|
||||
- if `to_date` is greater than the max date in the calendar, the data will be updated to the latest date
|
||||
- if there are data before `from_date` or after `to_date`, only the data between `from_date` and `to_date` are affected.
|
||||
|
||||
Args:
|
||||
record : Recorder
|
||||
to_date :
|
||||
update the prediction up to `to_date`
|
||||
if to_date is None:
|
||||
data will be updated to the latest date.
|
||||
from_date :
|
||||
the update will start from `from_date`
|
||||
if from_date is None:
|
||||
the updating will occur on the next tick after the latest data in historical data
|
||||
hist_ref : int
|
||||
Sometimes, the dataset will have historical dependencies.
|
||||
Leave the problem to users to set the length of historical dependency
|
||||
@@ -137,16 +127,13 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
)
|
||||
to_date = latest_date
|
||||
self.to_date = to_date
|
||||
|
||||
# FIXME: it will raise error when running routine with delay trainer
|
||||
# should we use another prediction updater for delay trainer?
|
||||
self.old_data: pd.DataFrame = record.load_object(fname)
|
||||
if from_date is None:
|
||||
# dropna is for being compatible to some data with future information(e.g. label)
|
||||
# The recent label data should be updated together
|
||||
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
|
||||
else:
|
||||
self.last_end = get_date_by_shift(from_date, -1, align="right")
|
||||
|
||||
# dropna is for being compatible to some data with future information(e.g. label)
|
||||
# The recent label data should be updated together
|
||||
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
|
||||
|
||||
def prepare_data(self) -> DatasetH:
|
||||
"""
|
||||
@@ -200,15 +187,6 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
...
|
||||
|
||||
|
||||
def _replace_range(data, new_data):
|
||||
dates = new_data.index.get_level_values("datetime")
|
||||
data = data.sort_index()
|
||||
data = data.drop(data.loc[dates.min() : dates.max()].index)
|
||||
cb_data = pd.concat([data, new_data], axis=0)
|
||||
cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
|
||||
return cb_data
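The difference between `_replace_range` and a plain `pd.concat` plus de-duplication shows up when the new data covers fewer instruments than the old data. The sketch below copies the function body from the diff under the name `replace_range`; the sample frames and instrument names are made up for illustration.

```python
import pandas as pd


def replace_range(data: pd.DataFrame, new_data: pd.DataFrame) -> pd.DataFrame:
    dates = new_data.index.get_level_values("datetime")
    data = data.sort_index()
    # drop every old row inside the date range covered by the new data
    data = data.drop(data.loc[dates.min() : dates.max()].index)
    cb_data = pd.concat([data, new_data], axis=0)
    cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
    return cb_data


names = ["datetime", "instrument"]
old_index = pd.MultiIndex.from_product(
    [pd.to_datetime(["2021-01-04", "2021-01-05", "2021-01-06"]), ["A", "B"]], names=names
)
old = pd.DataFrame({"score": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]}, index=old_index)

# the new predictions cover only instrument "A" on the last two days
new_index = pd.MultiIndex.from_product([pd.to_datetime(["2021-01-05", "2021-01-06"]), ["A"]], names=names)
new = pd.DataFrame({"score": [9.0, 9.5]}, index=new_index)

replaced = replace_range(old, new)
# concat-only variant: stale rows for "B" inside the range survive
concat_only = pd.concat([old, new], axis=0)
concat_only = concat_only[~concat_only.index.duplicated(keep="last")].sort_index()
```

In this example the range-replacing variant drops the old rows for `B` on the two overlapping days, while the concat-only variant keeps them next to the new rows for `A`.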
|
||||
|
||||
|
||||
class PredUpdater(DSBasedUpdater):
|
||||
"""
|
||||
Update the prediction in the Recorder
|
||||
@@ -218,9 +196,11 @@ class PredUpdater(DSBasedUpdater):
|
||||
# Load model
|
||||
model = self.rmdl.get_model()
|
||||
new_pred: pd.Series = model.predict(dataset)
|
||||
data = _replace_range(self.old_data, new_pred.to_frame("score"))
|
||||
|
||||
cb_pred = pd.concat([self.old_data, new_pred.to_frame("score")], axis=0)
|
||||
cb_pred = cb_pred.sort_index()
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
|
||||
return data
|
||||
return cb_pred
|
||||
|
||||
|
||||
class LabelUpdater(DSBasedUpdater):
|
||||
@@ -236,5 +216,6 @@ class LabelUpdater(DSBasedUpdater):
|
||||
|
||||
def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
|
||||
new_label = SignalRecord.generate_label(dataset)
|
||||
cb_data = _replace_range(self.old_data.sort_index(), new_label)
|
||||
cb_data = pd.concat([self.old_data, new_label], axis=0)
|
||||
cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
|
||||
return cb_data
|
||||
|
||||
@@ -158,7 +158,7 @@ class OnlineToolR(OnlineTool):
|
||||
exp_name = self._get_exp_name(exp_name)
|
||||
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
|
||||
def update_online_pred(self, to_date=None, from_date=None, exp_name: str = None):
|
||||
def update_online_pred(self, to_date=None, exp_name: str = None):
|
||||
"""
|
||||
Update the predictions of online models to to_date.
|
||||
|
||||
@@ -176,7 +176,7 @@ class OnlineToolR(OnlineTool):
|
||||
if issubclass(cls, TSDatasetH):
|
||||
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
|
||||
try:
|
||||
updater = PredUpdater(rec, to_date=to_date, from_date=from_date, hist_ref=hist_ref)
|
||||
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
|
||||
except LoadObjectError as e:
|
||||
# skip the recorder without pred
|
||||
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.backtest import executor
|
||||
import re
|
||||
import logging
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from typing import Union, List, Optional
|
||||
from typing import Union, List
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.utils.exceptions import LoadObjectError
|
||||
from ..contrib.evaluate import risk_analysis, indicator_analysis
|
||||
from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
from ..backtest import backtest as normal_backtest
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict, class_casting
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
|
||||
@@ -209,7 +215,6 @@ class HFSignalRecord(SignalRecord):
|
||||
"""
|
||||
|
||||
artifact_path = "hg_sig_analysis"
|
||||
depend_cls = SignalRecord
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder)
|
||||
@@ -265,31 +270,20 @@ class SigAnaRecord(RecordTemp):
|
||||
self.label_col = label_col
|
||||
self.skip_existing = skip_existing
|
||||
|
||||
def generate(self, label: Optional[pd.DataFrame] = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
label : Optional[pd.DataFrame]
|
||||
Label should be a dataframe.
|
||||
"""
|
||||
def generate(self, **kwargs):
|
||||
if self.skip_existing:
|
||||
try:
|
||||
self.check(include_self=True, parents=False)
|
||||
except FileNotFoundError:
|
||||
pass # continue to generating metrics
|
||||
else:
|
||||
logger.info("The results has previously generated, Generation skipped.")
|
||||
logger.info("The results has previously generated, generation skipped.")
|
||||
return
|
||||
|
||||
try:
|
||||
self.check()
|
||||
except FileNotFoundError:
|
||||
logger.warning("The dependent data does not exist. Generation skipped.")
|
||||
return
|
||||
self.check()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
if label is None:
|
||||
label = self.load("label.pkl")
|
||||
label = self.load("label.pkl")
|
||||
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
|
||||
logger.warn(f"Empty label.")
|
||||
return
|
||||
@@ -401,8 +395,8 @@ class PortAnaRecord(RecordTemp):
|
||||
if executor_config["kwargs"].get("generate_portfolio_metrics", False):
|
||||
_count, _freq = Freq.parse(executor_config["kwargs"]["time_per_step"])
|
||||
ret_freq.append(f"{_count}{_freq}")
|
||||
if "inner_executor" in executor_config["kwargs"]:
|
||||
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["inner_executor"]))
|
||||
if "sub_env" in executor_config["kwargs"]:
|
||||
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))
|
||||
return ret_freq
|
||||
|
||||
def generate(self, **kwargs):
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Callable, Dict, List
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.exp import Experiment
|
||||
|
||||
|
||||
class Collector(Serializable):
|
||||
@@ -147,9 +146,7 @@ class RecorderCollector(Collector):
|
||||
Init RecorderCollector.
|
||||
|
||||
Args:
|
||||
experiment:
|
||||
(Experiment or str): an instance of an Experiment or the name of an Experiment
|
||||
(Callable): a callable function, which returns a list of experiments
|
||||
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
|
||||
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
@@ -160,7 +157,6 @@ class RecorderCollector(Collector):
|
||||
super().__init__(process_list=process_list)
|
||||
if isinstance(experiment, str):
|
||||
experiment = R.get_exp(experiment_name=experiment)
|
||||
assert isinstance(experiment, (Experiment, Callable))
|
||||
self.experiment = experiment
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
@@ -196,16 +192,15 @@ class RecorderCollector(Collector):
|
||||
collect_dict = {}
|
||||
# filter records
|
||||
|
||||
if isinstance(self.experiment, Experiment):
|
||||
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
|
||||
recs = list(self.experiment.list_recorders(**self.list_kwargs).values())
|
||||
elif isinstance(self.experiment, Callable):
|
||||
recs = self.experiment()
|
||||
|
||||
recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)]
|
||||
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
|
||||
recs = self.experiment.list_recorders(**self.list_kwargs)
|
||||
recs_flt = {}
|
||||
for rid, rec in recs.items():
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
logger = get_module_logger("RecorderCollector")
|
||||
for rec in recs:
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
if self.ART_KEY_RAW == key:
|
||||
|
||||
@@ -27,7 +27,6 @@ from qlib import auto_init, get_module_logger
|
||||
from tqdm.cli import tqdm
|
||||
|
||||
from .utils import get_mongodb
|
||||
from ...config import C
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -94,7 +93,6 @@ class TaskManager:
|
||||
"""
|
||||
self.task_pool: pymongo.collection.Collection = getattr(get_mongodb(), task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.logger.info(f"task_pool:{task_pool}")
|
||||
|
||||
@staticmethod
|
||||
def list() -> list:
|
||||
@@ -110,7 +108,7 @@ class TaskManager:
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
for k in list(task.keys()):
|
||||
if k.startswith(prefix):
|
||||
task[k] = Binary(pickle.dumps(task[k], protocol=C.dump_protocol_version))
|
||||
task[k] = Binary(pickle.dumps(task[k]))
|
||||
return task
|
||||
|
||||
def _decode_task(self, task):
|
||||
@@ -361,10 +359,7 @@ class TaskManager:
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
self.task_pool.update_one(
|
||||
{"_id": task["_id"]},
|
||||
{"$set": {"status": status, "res": Binary(pickle.dumps(res, protocol=C.dump_protocol_version))}},
|
||||
)
|
||||
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
|
||||
def return_task(self, task, status=STATUS_WAITING):
|
||||
"""
|
||||
|
||||
@@ -46,7 +46,7 @@ def get_mongodb() -> Database:
|
||||
except KeyError:
|
||||
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
|
||||
raise
|
||||
get_module_logger("task").info(f"mongo config:{cfg}")
|
||||
|
||||
client = MongoClient(cfg["task_url"])
|
||||
return client.get_database(name=cfg["task_db_name"])
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# 1min data (Optional for running non-high-frequency strategies)
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
```
|
||||
|
||||
### Download US Data
|
||||
|
||||
@@ -1,60 +0,0 @@
# Data Collector

## Introduction

Scripts for data collection

- yahoo: get *US/CN* stock data from *Yahoo Finance*
- fund: get fund data from *http://fund.eastmoney.com*
- cn_index: get *CN index* from *http://www.csindex.com.cn*, *CSI300*/*CSI100*
- us_index: get *US index* from *https://en.wikipedia.org/wiki*, *SP500*/*NASDAQ100*/*DJIA*/*SP400*
- contrib: scripts for some auxiliary functions

## Custom Data Collection

> Specific implementation reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo

1. Create a dataset code directory in the current directory
2. Add `collector.py`
    - add collector class:
        ```python
        import sys
        from pathlib import Path

        # make the parent `scripts` directory importable
        CUR_DIR = Path(__file__).resolve().parent
        sys.path.append(str(CUR_DIR.parent.parent))
        from data_collector.base import BaseCollector, BaseNormalize, BaseRun
        class UserCollector(BaseCollector):
            ...
        ```
    - add normalize class:
        ```python
        class UserNormalize(BaseNormalize):
            ...
        ```
    - add `CLI` class:
        ```python
        class Run(BaseRun):
            ...
        ```
3. Add `README.md`
4. Add `requirements.txt`

## Description of dataset

|             | Basic data |
|-------------|------------|
| Features    | **Price/Volume**: <br> - $close/$open/$low/$high/$volume/$change/$factor |
| Calendar    | **\<freq>.txt**: <br> - day.txt<br> - 1min.txt |
| Instruments | **\<market>.txt**: <br> - required: **all.txt**; <br> - csi300.txt/csi500.txt/sp500.txt |

- `Features`: the data must be **numeric**
    - if not **adjusted**, **factor=1**

### Data-dependent component

> For a component to run correctly, the data it depends on are required

| Component      | Required data                                     |
|----------------|---------------------------------------------------|
| Data retrieval | Features, Calendar, Instruments                   |
| Backtest       | **Features[Price/Volume]**, Calendar, Instruments |
@@ -6,12 +6,13 @@ import abc
|
||||
import sys
|
||||
import importlib
|
||||
from io import BytesIO
|
||||
from typing import List, Iterable
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
@@ -21,10 +22,12 @@ from data_collector.index import IndexBase
|
||||
from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry
|
||||
|
||||
|
||||
NEW_COMPANIES_URL = "https://csi-web-dev.oss-cn-shanghai-finance-1-pub.aliyuncs.com/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls"
|
||||
|
||||
|
||||
INDEX_CHANGES_URL = "https://www.csindex.com.cn/csindex-home/search/search-content?lang=cn&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC&pageNum={page_num}&pageSize={page_size}&sortField=date&dateRange=all&contentType=announcement"
|
||||
# INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A"
|
||||
# 2020-11-27 Announcement title change
|
||||
INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89"
|
||||
|
||||
REQ_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.101 Safari/537.36 Edg/91.0.864.48"
|
||||
@@ -52,11 +55,7 @@ class CSIIndex(IndexBase):
|
||||
-------
|
||||
calendar list
|
||||
"""
|
||||
_calendar = getattr(self, "_calendar_list", None)
|
||||
if not _calendar:
|
||||
_calendar = get_calendar_list(bench_code=self.index_name.upper())
|
||||
setattr(self, "_calendar_list", _calendar)
|
||||
return _calendar
|
||||
return get_calendar_list(bench_code=self.index_name.upper())
|
||||
|
||||
@property
|
||||
def new_companies_url(self) -> str:
|
||||
@@ -97,27 +96,6 @@ class CSIIndex(IndexBase):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""formatting the datetime in an instrument
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst_df: pd.DataFrame
|
||||
inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
if self.freq != "day":
|
||||
inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply(
|
||||
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
|
||||
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
return inst_df
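For non-daily frequencies, `format_datetime` turns plain dates into session timestamps: 09:30 for the start date and 15:00 for the end date. A standalone sketch of the same transformation follows; the function name `add_session_bounds` and the column names are hypothetical (the method above uses `self.START_DATE_FIELD` and `self.END_DATE_FIELD`), and the sketch returns a copy instead of modifying the frame in place.

```python
import pandas as pd

DATETIME_FMT = "%Y-%m-%d %H:%M:%S"


def add_session_bounds(inst_df: pd.DataFrame, start_col: str = "start_datetime", end_col: str = "end_datetime") -> pd.DataFrame:
    """Return a copy whose start/end dates carry the session open/close time."""
    out = inst_df.copy()
    # session open: 09:30 on the start date
    out[start_col] = out[start_col].apply(
        lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30)).strftime(DATETIME_FMT)
    )
    # session close: 15:00 on the end date
    out[end_col] = out[end_col].apply(
        lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0)).strftime(DATETIME_FMT)
    )
    return out


# made-up instrument row for illustration
inst_df = pd.DataFrame(
    {"symbol": ["SH600000"], "start_datetime": ["2021-01-04"], "end_datetime": ["2021-06-30"]}
)
formatted = add_session_bounds(inst_df)
```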
|
||||
|
||||
def get_changes(self) -> pd.DataFrame:
|
||||
"""get companies changes
|
||||
|
||||
@@ -136,8 +114,7 @@ class CSIIndex(IndexBase):
|
||||
res = []
|
||||
for _url in self._get_change_notices_url():
|
||||
_df = self._read_change_from_url(_url)
|
||||
if not _df.empty:
|
||||
res.append(_df)
|
||||
res.append(_df)
|
||||
logger.info("get companies changes finish")
|
||||
return pd.concat(res, sort=False)
|
||||
|
||||
@@ -157,56 +134,6 @@ class CSIIndex(IndexBase):
|
||||
symbol = f"{int(symbol):06}"
|
||||
return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}"
|
||||
|
||||
def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame:
|
||||
content = retry_request(excel_url, exclude_status=[404]).content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
fp.write(content)
|
||||
tmp = []
|
||||
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df["type"] = _type
|
||||
_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
return df
|
||||
|
||||
def _parse_table(self, content: str, add_date: pd.DataFrame, remove_date: pd.DataFrame) -> pd.DataFrame:
|
||||
df = pd.DataFrame()
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(content):
|
||||
if _df.shape[-1] != 4:
|
||||
continue
|
||||
_tmp_count += 1
|
||||
if self.html_table_index + 1 > _tmp_count:
|
||||
continue
|
||||
tmp = []
|
||||
for _s, _type, _date in [
|
||||
(_df.iloc[2:, 0], self.REMOVE, remove_date),
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
str(
|
||||
self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
|
||||
).resolve()
|
||||
)
|
||||
)
|
||||
break
|
||||
return df
|
||||
|
||||
def _read_change_from_url(self, url: str) -> pd.DataFrame:
|
||||
"""read change from url
|
||||
|
||||
@@ -226,60 +153,75 @@ class CSIIndex(IndexBase):
|
||||
date: pd.Timestamp
|
||||
type: str, value from ["add", "remove"]
|
||||
"""
|
||||
resp = retry_request(url).json()["data"]
|
||||
title = resp["title"]
|
||||
if not title.startswith("关于"):
|
||||
return pd.DataFrame()
|
||||
if "沪深300" not in title:
|
||||
return pd.DataFrame()
|
||||
|
||||
logger.info(f"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}")
|
||||
_text = resp["content"]
|
||||
resp = retry_request(url)
|
||||
_text = resp.text
|
||||
date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text)
|
||||
if len(date_list) >= 2:
|
||||
add_date = pd.Timestamp("-".join(date_list[0]))
|
||||
else:
|
||||
_date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]))
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0)
|
||||
if "盘后" in _text or "市后" in _text:
|
||||
add_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=1)
|
||||
remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1)
|
||||
|
||||
excel_url = None
|
||||
if resp.get("enclosureList", []):
|
||||
excel_url = resp["enclosureList"][0]["fileUrl"]
|
||||
else:
|
||||
excel_url_list = re.findall('.*href="(.*?xls.*?)".*', _text)
|
||||
if excel_url_list:
|
||||
excel_url = excel_url_list[0]
|
||||
if not excel_url.startswith("http"):
|
||||
excel_url = excel_url if excel_url.startswith("/") else "/" + excel_url
|
||||
excel_url = f"http://www.csindex.com.cn{excel_url}"
|
||||
if excel_url:
|
||||
logger.info(f"get {add_date} changes from excel, title={title}, excel_url={excel_url}")
|
||||
try:
|
||||
df = self._parse_excel(excel_url, add_date, remove_date)
|
||||
except ValueError:
|
||||
logger.warning(f"error downloading file: {excel_url}, will parse the table from the content")
|
||||
df = self._parse_table(_text, add_date, remove_date)
|
||||
else:
|
||||
logger.info(f"get {add_date} changes from url content, title={title}")
|
||||
df = self._parse_table(_text, add_date, remove_date)
|
||||
logger.info(f"get {add_date} changes")
|
||||
try:
|
||||
excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0]
|
||||
content = retry_request(f"http://www.csindex.com.cn{excel_url}", exclude_status=[404]).content
|
||||
_io = BytesIO(content)
|
||||
df_map = pd.read_excel(_io, sheet_name=None)
|
||||
with self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.{excel_url.split('.')[-1]}"
|
||||
).open("wb") as fp:
|
||||
fp.write(content)
|
||||
tmp = []
|
||||
for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]:
|
||||
_df = df_map[_s_name]
|
||||
_df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]]
|
||||
_df = _df.applymap(self.normalize_symbol)
|
||||
_df.columns = [self.SYMBOL_FIELD_NAME]
|
||||
_df["type"] = _type
|
||||
_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_df)
|
||||
df = pd.concat(tmp)
|
||||
except Exception as e:
|
||||
df = None
|
||||
_tmp_count = 0
|
||||
for _df in pd.read_html(resp.content):
|
||||
if _df.shape[-1] != 4:
|
||||
continue
|
||||
_tmp_count += 1
|
||||
if self.html_table_index + 1 > _tmp_count:
|
||||
continue
|
||||
tmp = []
|
||||
for _s, _type, _date in [
|
||||
(_df.iloc[2:, 0], self.REMOVE, remove_date),
|
||||
(_df.iloc[2:, 2], self.ADD, add_date),
|
||||
]:
|
||||
_tmp_df = pd.DataFrame()
|
||||
_tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol)
|
||||
_tmp_df["type"] = _type
|
||||
_tmp_df[self.DATE_FIELD_NAME] = _date
|
||||
tmp.append(_tmp_df)
|
||||
df = pd.concat(tmp)
|
||||
df.to_csv(
|
||||
str(
|
||||
self.cache_dir.joinpath(
|
||||
f"{self.index_name.lower()}_changes_{add_date.strftime('%Y%m%d')}.csv"
|
||||
).resolve()
|
||||
)
|
||||
)
|
||||
break
|
||||
return df
|
||||
|
||||
def _get_change_notices_url(self) -> Iterable[str]:
|
||||
def _get_change_notices_url(self) -> List[str]:
|
||||
"""get change notices url
|
||||
|
||||
Returns
|
||||
-------
|
||||
[url1, url2]
|
||||
"""
|
||||
page_num = 1
|
||||
page_size = 5
|
||||
data = retry_request(self.changes_url.format(page_size=page_size, page_num=page_num)).json()
|
||||
data = retry_request(self.changes_url.format(page_size=data["total"], page_num=page_num)).json()
|
||||
for item in data["data"]:
|
||||
yield f"https://www.csindex.com.cn/csindex-home/announcement/queryAnnouncementById?id={item['id']}"
|
||||
resp = retry_request(self.changes_url)
|
||||
html = etree.HTML(resp.text)
|
||||
return html.xpath("//*[@id='itemContainer']//li/a/@href")
|
||||
|
||||
def get_new_companies(self) -> pd.DataFrame:
|
||||
"""
|
||||
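The `_read_change_from_url` hunk above pulls the effective date out of the announcement text with a year/month/day regular expression. The following is a minimal sketch of that extraction step only, assuming a made-up notice string and a hypothetical helper name (`extract_dates`); it returns `datetime.date` objects from the standard library rather than the `pd.Timestamp` values used in the collector.

```python
import re
from datetime import date

# Same pattern as in the hunk above: year, month, and day separated by 年 / 月 / 日.
DATE_PATTERN = re.compile(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日")


def extract_dates(text: str) -> list:
    """Return every `YYYY年M月D日` date found in `text` as a `datetime.date`."""
    return [date(int(y), int(m), int(d)) for y, m, d in DATE_PATTERN.findall(text)]


# Made-up notice text, used only for illustration.
notice = "定于2021年12月13日正式生效，公告日期为2021年11月26日。"
print(extract_dates(notice))
```

In the collector, the first match is taken as the add date when at least two dates are found; the sketch leaves that choice to the caller.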
@@ -307,7 +249,7 @@ class CSIIndex(IndexBase):
|
||||
df = df.iloc[:, [0, 4]]
|
||||
df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME]
|
||||
df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol)
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str))
|
||||
df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD])
|
||||
df[self.START_DATE_FIELD] = self.bench_start_date
|
||||
logger.info("end of get new companies.")
|
||||
return df
|
||||
@@ -324,7 +266,7 @@ class CSI300(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
class CSI100(CSIIndex):
|
||||
@@ -338,16 +280,11 @@ class CSI100(CSIIndex):
|
||||
|
||||
@property
|
||||
def html_table_index(self):
|
||||
return 2
|
||||
return 1
|
||||
|
||||
|
||||
def get_instruments(
|
||||
qlib_dir: str,
|
||||
index_name: str,
|
||||
method: str = "parse_instruments",
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -359,8 +296,6 @@ def get_instruments(
|
||||
index name, value from ["csi100", "csi300"]
|
||||
method: str
|
||||
method, value from ["parse_instruments", "save_new_companies"]
|
||||
freq: str
|
||||
freq, value from ["day", "1min"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
@@ -377,7 +312,7 @@ def get_instruments(
|
||||
"""
|
||||
_cur_module = importlib.import_module("data_collector.cn_index.collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
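`get_instruments` above resolves the collector class from its name with `importlib.import_module` and `getattr`, then calls the requested method by name. A minimal sketch of the same lookup pattern follows. The helper `call_by_name` is hypothetical, and a standard-library class stands in for the collector classes so that the snippet runs without the `data_collector` package.

```python
import importlib


def call_by_name(module_path: str, class_name: str, method: str, *args, **kwargs):
    """Import `module_path`, instantiate `class_name`, then call `method` on it."""
    module = importlib.import_module(module_path)
    obj = getattr(module, class_name)(*args, **kwargs)
    return getattr(obj, method)()


# Stand-in for a collector class: collections.Counter("aab").most_common()
result = call_by_name("collections", "Counter", "most_common", "aab")
print(result)  # [('a', 2), ('b', 1)]
```

Because the class and method are looked up at run time, a misspelled name only surfaces as an `AttributeError` when the call is made.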
@@ -1,3 +1,4 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
pandas
|
||||
|
||||
@@ -26,14 +26,7 @@ class IndexBase:
|
||||
ADD = "add"
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
qlib_dir: [str, Path] = None,
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
):
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -42,8 +35,6 @@ class IndexBase:
|
||||
index name
|
||||
qlib_dir: str
|
||||
qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data")
|
||||
freq: str
|
||||
freq, value from ["day", "1min"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
@@ -58,7 +49,6 @@ class IndexBase:
|
||||
self.cache_dir.mkdir(exist_ok=True, parents=True)
|
||||
self._request_retry = request_retry
|
||||
self._retry_sleep = retry_sleep
|
||||
self.freq = freq
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
@@ -116,21 +106,6 @@ class IndexBase:
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_changes")
|
||||
|
||||
@abc.abstractmethod
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""formatting the datetime in an instrument
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst_df: pd.DataFrame
|
||||
inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite format_datetime")
|
||||
|
||||
def save_new_companies(self):
|
||||
"""save new companies
|
||||
|
||||
@@ -231,7 +206,6 @@ class IndexBase:
|
||||
_inst_prefix = self.INST_PREFIX.strip()
|
||||
if _inst_prefix:
|
||||
inst_df["save_inst"] = inst_df[self.SYMBOL_FIELD_NAME].apply(lambda x: f"{_inst_prefix}{x}")
|
||||
inst_df = self.format_datetime(inst_df)
|
||||
inst_df.to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None
|
||||
)
|
||||
|
||||
@@ -37,16 +37,9 @@ class WIKIIndex(IndexBase):
|
||||
# https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
qlib_dir: [str, Path] = None,
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
):
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
super(WIKIIndex, self).__init__(
|
||||
index_name=index_name, qlib_dir=qlib_dir, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
|
||||
self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}"
|
||||
@@ -78,24 +71,6 @@ class WIKIIndex(IndexBase):
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_changes")
|
||||
|
||||
def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""formatting the datetime in an instrument
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inst_df: pd.DataFrame
|
||||
inst_df.columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
if self.freq != "day":
|
||||
inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply(
|
||||
lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59)).strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
return inst_df
|
||||
|
||||
@property
|
||||
def calendar_list(self) -> List[pd.Timestamp]:
|
||||
"""get history trading date
|
||||
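The `format_datetime` variant in the hunk above moves each end date to 23:59 of the same day when `freq` is not `"day"`, so that minute-level data on the last day is still covered. A minimal sketch of that shift, assuming a hypothetical helper name (`end_of_day`) and using the standard-library `datetime` module in place of `pd.Timestamp` and `pd.Timedelta`:

```python
from datetime import datetime, timedelta


def end_of_day(day: str) -> str:
    """Shift a `YYYY-MM-DD` end date to 23:59:00 of the same day."""
    stamp = datetime.strptime(day, "%Y-%m-%d") + timedelta(hours=23, minutes=59)
    return stamp.strftime("%Y-%m-%d %H:%M:%S")


print(end_of_day("2021-12-08"))  # 2021-12-08 23:59:00
```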
@@ -270,12 +245,7 @@ class SP400Index(WIKIIndex):
|
||||
|
||||
|
||||
def get_instruments(
|
||||
qlib_dir: str,
|
||||
index_name: str,
|
||||
method: str = "parse_instruments",
|
||||
freq: str = "day",
|
||||
request_retry: int = 5,
|
||||
retry_sleep: int = 3,
|
||||
qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -287,8 +257,6 @@ def get_instruments(
|
||||
index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"]
|
||||
method: str
|
||||
method, value from ["parse_instruments", "save_new_companies"]
|
||||
freq: str
|
||||
freq, value from ["day", "1min"]
|
||||
request_retry: int
|
||||
request retry, by default 5
|
||||
retry_sleep: int
|
||||
@@ -297,15 +265,15 @@ def get_instruments(
|
||||
Examples
|
||||
-------
|
||||
# parse instruments
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies
|
||||
$ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
|
||||
"""
|
||||
_cur_module = importlib.import_module("data_collector.us_index.collector")
|
||||
obj = getattr(_cur_module, f"{index_name.upper()}Index")(
|
||||
qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep
|
||||
)
|
||||
getattr(obj, method)()
|
||||
|
||||
|
||||
@@ -27,8 +27,7 @@ pip install -r requirements.txt
|
||||
## Collector Data
|
||||
|
||||
### Get Qlib data(`bin file`)
|
||||
> `qlib-data` from *YahooFinance*, is the data that has been dumped and can be used directly in `qlib`.
|
||||
> This ready-made qlib-data is not updated regularly. If users want the latest data, please follow [these steps](#collector-yahoofinance-data-to-qlib) to download the latest data.
|
||||
> `qlib-data` from *YahooFinance*, is the data that has been dumped and can be used directly in `qlib`
|
||||
|
||||
- get data: `python scripts/get_data.py qlib_data`
|
||||
- parameters:
|
||||
@@ -58,8 +57,7 @@ pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Collector *YahooFinance* data to qlib
|
||||
> collect *YahooFinance* data and *dump* it into `qlib` format.
|
||||
> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
|
||||
> collect *YahooFinance* data and *dump* it into `qlib` format
|
||||
1. download data to csv: `python scripts/data_collector/yahoo/collector.py download_data`
|
||||
|
||||
- parameters:
|
||||
|
||||
@@ -601,19 +601,11 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
# - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits.
|
||||
# - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)`
|
||||
def _calc_factor(df_1d: pd.DataFrame):
|
||||
try:
|
||||
_date = pd.Timestamp(pd.Timestamp(df_1d[self._date_field_name].iloc[0]).date())
|
||||
df_1d["factor"] = (
|
||||
data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"]
|
||||
)
|
||||
df_1d["paused"] = data_1d.loc[_date]["paused"]
|
||||
except Exception:
|
||||
df_1d["factor"] = np.nan
|
||||
df_1d["paused"] = np.nan
|
||||
return df_1d
|
||||
|
||||
df = df.groupby([df[self._date_field_name].dt.date]).apply(_calc_factor)
|
||||
df["date_tmp"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date())
|
||||
df.set_index("date_tmp", inplace=True)
|
||||
df.loc[:, "factor"] = data_1d["close"] / df["close"]
|
||||
df.loc[:, "paused"] = data_1d["paused"]
|
||||
df.reset_index("date_tmp", drop=True, inplace=True)
|
||||
|
||||
if self.CONSISTENT_1d:
|
||||
# the date sequence is consistent with 1d
|
||||
|
||||
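The `_calc_factor` variant in the hunk above computes one factor per trading day: the daily close divided by the last valid minute-level close of that day, with NaN as the fallback when the lookup fails. The sketch below illustrates that rule with plain Python containers instead of pandas. The function name `daily_factors` and the input shapes are assumptions made for illustration, and the minute bars are assumed to be sorted by time.

```python
import math
from collections import defaultdict
from datetime import date, datetime


def daily_factors(minute_bars, daily_close):
    """Map each date to `daily close / last non-NaN minute close` of that date."""
    by_date = defaultdict(list)
    for stamp, close in minute_bars:
        by_date[stamp.date()].append(close)

    factors = {}
    for day, closes in by_date.items():
        valid = [c for c in closes if not math.isnan(c)]
        if valid and day in daily_close:
            factors[day] = daily_close[day] / valid[-1]
        else:
            factors[day] = math.nan  # mirrors the `except` branch in the hunk
    return factors


# Made-up bars: the second day has no matching daily close.
bars = [
    (datetime(2021, 12, 8, 9, 31), 10.0),
    (datetime(2021, 12, 8, 15, 0), float("nan")),
    (datetime(2021, 12, 9, 9, 31), 20.0),
]
factors = daily_factors(bars, {date(2021, 12, 8): 5.0})
print(factors[date(2021, 12, 8)])  # 0.5
```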
23
setup.py
@@ -6,21 +6,6 @@ import numpy
|
||||
|
||||
from setuptools import find_packages, setup, Extension
|
||||
|
||||
|
||||
def read(rel_path: str) -> str:
|
||||
here = os.path.abspath(os.path.dirname(__file__))
|
||||
with open(os.path.join(here, rel_path), encoding="utf-8") as fp:
|
||||
return fp.read()
|
||||
|
||||
|
||||
def get_version(rel_path: str) -> str:
|
||||
for line in read(rel_path).splitlines():
|
||||
if line.startswith("__version__"):
|
||||
delim = '"' if '"' in line else "'"
|
||||
return line.split(delim)[1]
|
||||
raise RuntimeError("Unable to find version string.")
|
||||
|
||||
|
||||
# Package meta-data.
|
||||
NAME = "pyqlib"
|
||||
DESCRIPTION = "A Quantitative-research Platform"
|
||||
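The `get_version` helper in the hunk above scans a file for the `__version__` line and returns the quoted value. A minimal sketch of the parsing step on its own, assuming a hypothetical function name (`parse_version`) that takes the source text directly instead of a file path:

```python
def parse_version(source: str) -> str:
    """Return the quoted value of the first `__version__ = ...` line in `source`."""
    for line in source.splitlines():
        if line.startswith("__version__"):
            # Accept either quote style around the version string.
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    raise RuntimeError("Unable to find version string.")


print(parse_version('__version__ = "0.8.0"\n'))  # 0.8.0
```

Reading the version from the source text avoids importing the package at build time, which matters when its compiled extensions are not built yet.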
@@ -29,7 +14,11 @@ REQUIRES_PYTHON = ">=3.5.0"
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
VERSION = get_version("qlib/__init__.py")
|
||||
CURRENT_DIR = Path(__file__).absolute().parent
|
||||
_version_src = CURRENT_DIR / "VERSION.txt"
|
||||
_version_dst = CURRENT_DIR / "qlib" / "VERSION.txt"
|
||||
copyfile(_version_src, _version_dst)
|
||||
VERSION = _version_dst.read_text(encoding="utf-8").strip()
|
||||
|
||||
# Detect Cython
|
||||
try:
|
||||
@@ -58,7 +47,7 @@ REQUIRED = [
|
||||
"python-redis-lock>=3.3.1",
|
||||
"schedule>=0.6.0",
|
||||
"cvxpy>=1.0.21",
|
||||
"hyperopt==0.1.2",
|
||||
"hyperopt==0.1.1",
|
||||
"fire>=0.3.1",
|
||||
"statsmodels",
|
||||
"xlrd>=1.0.0",
|
||||
|
||||
@@ -21,7 +21,11 @@ class TestRolling(TestAutoData):
|
||||
"""
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]
|
||||
task["record"] = {
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
"kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
|
||||
}
|
||||
|
||||
exp_name = "online_srv_test"
|
||||
|
||||
@@ -53,27 +57,6 @@ class TestRolling(TestAutoData):
|
||||
|
||||
online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))
|
||||
|
||||
good_pred = rec.load_object("pred.pkl")
|
||||
|
||||
mod_range = slice(latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10))
|
||||
mod_range2 = slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2))
|
||||
mod_pred = good_pred.copy()
|
||||
|
||||
mod_pred.loc[mod_range] = -1
|
||||
mod_pred.loc[mod_range2] = -2
|
||||
|
||||
rec.save_objects(**{"pred.pkl": mod_pred})
|
||||
online_tool.update_online_pred(
|
||||
to_date=latest_date - pd.Timedelta(days=10), from_date=latest_date - pd.Timedelta(days=20)
|
||||
)
|
||||
|
||||
updated_pred = rec.load_object("pred.pkl")
|
||||
|
||||
# this range is not fixed
|
||||
self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())
|
||||
# this range is fixed now
|
||||
self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())
|
||||
|
||||
def test_update_label(self):
|
||||
|
||||
task = copy.deepcopy(CSI300_GBDT_TASK)
|
||||
|
||||
@@ -75,7 +75,7 @@ class TestStorage(TestAutoData):
|
||||
|
||||
"""
|
||||
|
||||
instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri, freq="day")
|
||||
instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri)
|
||||
|
||||
for inst, spans in instrument.data.items():
|
||||
assert isinstance(inst, str) and isinstance(
|
||||
@@ -88,7 +88,7 @@ class TestStorage(TestAutoData):
|
||||
|
||||
print(f"instrument['SH600000']: {instrument['SH600000']}")
|
||||
|
||||
instrument = InstrumentStorage(market="csi300", provider_uri="not_found", freq="day")
|
||||
instrument = InstrumentStorage(market="csi300", provider_uri="not_found")
|
||||
with self.assertRaises(ValueError):
|
||||
print(instrument.data)
|
||||
|
||||
@@ -163,9 +163,8 @@ class TestStorage(TestAutoData):
|
||||
|
||||
feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
print(feature[0])
|
||||
with self.assertRaises(ValueError):
|
||||
print(feature[:].empty)
|
||||
with self.assertRaises(ValueError):
|
||||
print(feature.data.empty)
|
||||
assert feature[0] == (None, None), "FeatureStorage does not exist, feature[i] should return `(None, None)`"
|
||||
assert feature[:].empty, "FeatureStorage does not exist, feature[:] should return `pd.Series(dtype=np.float32)`"
|
||||
assert (
|
||||
feature.data.empty
|
||||
), "FeatureStorage does not exist, feature.data should return `pd.Series(dtype=np.float32)`"
|
||||
|
||||
@@ -201,7 +201,6 @@ class TestAllFlow(TestAutoData):
|
||||
0.10,
|
||||
"backtest failed",
|
||||
)
|
||||
self.assertTrue(not analyze_df.isna().any().any(), "backtest failed")
|
||||
|
||||
def test_3_expmanager(self):
|
||||
pass_default, pass_current, uri_path = fake_experiment()
|
||||
|
||||
@@ -75,35 +75,6 @@ class TestDataset(TestAutoData):
|
||||
equal = np.isclose(data_from_df, data_from_ds)
|
||||
self.assertTrue(equal[~np.isnan(data_from_df)].all())
|
||||
|
||||
if False:
|
||||
# 3) get both index and data
|
||||
# NOTE: We don't want to rely on PyTorch, so this test can't be included. It is just an example
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
class IdxSampler:
|
||||
def __init__(self, sampler):
|
||||
self.sampler = sampler
|
||||
|
||||
def __getitem__(self, i: int):
|
||||
return self.sampler[i], i
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sampler)
|
||||
|
||||
i = len(tsds) - 1
|
||||
idx = tsds.get_index()
|
||||
tsds[i]
|
||||
idx[i]
|
||||
|
||||
s_w_i = IdxSampler(tsds)
|
||||
test_loader = DataLoader(s_w_i)
|
||||
|
||||
s_w_i[3]
|
||||
for data, i in test_loader:
|
||||
break
|
||||
print(data.shape)
|
||||
print(idx[i])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=10)
|
||||
|
||||