mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge remote-tracking branch 'qlib/main' into save_inst
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -13,7 +13,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, ubuntu-16.04, ubuntu-18.04, ubuntu-20.04, macos-latest]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
__pycache__/
|
||||
|
||||
*.pyc
|
||||
*.pyd
|
||||
*.so
|
||||
*.ipynb
|
||||
.ipynb_checkpoints
|
||||
|
||||
27
README.md
27
README.md
@@ -69,7 +69,20 @@ Here is a quick **[demo](https://terminalizer.com/view/3f24561a4470)** shows how
|
||||
|
||||
## Installation
|
||||
|
||||
Users can easily install ``Qlib`` by pip according to the following command(Currently, Qlib only support Python 3.6, 3.7 and 3.8).
|
||||
This table demonstrates the supported Python version of `Qlib`:
|
||||
| | install with pip | install from source | plot |
|
||||
| ------------- |:---------------------:|:--------------------:|:----:|
|
||||
| Python 3.6 | :heavy_check_mark: | :heavy_check_mark: (only with `Anaconda`) | :heavy_check_mark: |
|
||||
| Python 3.7 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.8 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Python 3.9 | :x: | :heavy_check_mark: | :x: |
|
||||
|
||||
**Note**:
|
||||
1. Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
2. For Python 3.9, `Qlib` supports running workflows such as training models, doing backtest and plot most of the related figures (those included in [notebook](examples/workflow_by_code.ipynb)). However, plotting for the *model performance* is not supported for now and we will fix this when the dependent packages are upgraded in the future.
|
||||
|
||||
### Install with pip
|
||||
Users can easily install ``Qlib`` by pip according to the following command.
|
||||
|
||||
```bash
|
||||
pip install pyqlib
|
||||
@@ -77,6 +90,7 @@ Users can easily install ``Qlib`` by pip according to the following command(Curr
|
||||
|
||||
**Note**: pip will install the latest stable qlib. However, the main branch of qlib is in active development. If you want to test the latest scripts or functions in the main branch. Please install qlib with the methods below.
|
||||
|
||||
### Install from source
|
||||
Also, users can install the latest dev version ``Qlib`` by the source code according to the following steps:
|
||||
|
||||
* Before installing ``Qlib`` from source, users need to install some dependencies:
|
||||
@@ -85,7 +99,6 @@ Also, users can install the latest dev version ``Qlib`` by the source code accor
|
||||
pip install numpy
|
||||
pip install --upgrade cython
|
||||
```
|
||||
**Note**: Please pay attention that installing cython in Python 3.6 will raise some error when installing ``Qlib`` from source. If users use Python 3.6 on their machines, it is recommended to *upgrade* Python to version 3.7 or use `conda`'s Python to install ``Qlib`` from source.
|
||||
|
||||
* Clone the repository and install ``Qlib`` as follows.
|
||||
* If you haven't installed qlib by the command ``pip install pyqlib`` before:
|
||||
@@ -149,6 +162,10 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
cd examples # Avoid running program under the directory contains `qlib`
|
||||
qrun benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
If users want to use `qrun` under debug mode, please use the following command:
|
||||
```bash
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
```
|
||||
The result of `qrun` is as follows, please refer to [Intraday Trading](https://qlib.readthedocs.io/en/latest/component/backtest.html) for more details about the result.
|
||||
|
||||
```bash
|
||||
@@ -209,11 +226,12 @@ Here is a list of models built on `Qlib`.
|
||||
- [GBDT based on XGBoost (Tianqi Chen, et al.)](qlib/contrib/model/xgboost.py)
|
||||
- [MLP based on pytorch](qlib/contrib/model/pytorch_nn.py)
|
||||
- [GRU based on pytorch (Kyunghyun Cho, et al.)](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorcn (Sepp Hochreiter, et al.)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorcn (Yao Qin, et al.)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [LSTM based on pytorch (Sepp Hochreiter, et al.)](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [ALSTM based on pytorch (Yao Qin, et al.)](qlib/contrib/model/pytorch_alstm.py)
|
||||
- [GATs based on pytorch (Petar Velickovic, et al.)](qlib/contrib/model/pytorch_gats.py)
|
||||
- [SFM based on pytorch (Liheng Zhang, et al.)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al.)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al.)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -295,6 +313,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
@@ -126,17 +126,17 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
The arguments of `--include_fields` should correspond with the column names of CSV files. The columns names of dataset provided by ``Qlib`` should include open, close, high, low, volume and factor at least.
|
||||
|
||||
- `open`
|
||||
The opening price
|
||||
The adjusted opening price
|
||||
- `close`
|
||||
The closing price
|
||||
The adjusted closing price
|
||||
- `high`
|
||||
The highest price
|
||||
The adjusted highest price
|
||||
- `low`
|
||||
The lowest price
|
||||
The adjusted lowest price
|
||||
- `volume`
|
||||
The trading volume
|
||||
The adjusted trading volume
|
||||
- `factor`
|
||||
The Restoration factor
|
||||
The Restoration factor. Normally, original_price = adj_price / factor
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
|
||||
|
||||
@@ -34,8 +34,9 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
This experiment management system defines a set of interface and provided a concrete implementation based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
|
||||
Qlib Recorder
|
||||
===================
|
||||
|
||||
@@ -103,6 +103,12 @@ After saving the config into `configuration.yaml`, users could start the workflo
|
||||
|
||||
qrun configuration.yaml
|
||||
|
||||
If users want to use ``qrun`` under debug mode, please use the following command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -m pdb qlib/workflow/cli.py examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
.. note::
|
||||
|
||||
`qrun` will be placed in your $PATH directory when installing ``Qlib``.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
Cython
|
||||
cmake
|
||||
numpy
|
||||
scipy
|
||||
scikit-learn
|
||||
scikit-learn
|
||||
|
||||
@@ -63,6 +63,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
If Qlib fails to connect redis via `redis_host` and `redis_port`, cache mechanism will not be used! Please refer to `Cache <../component/data.html#cache>`_ for details.
|
||||
- `exp_manager`
|
||||
Type: dict, optional parameter, the setting of `experiment manager` to be used in qlib. Users can specify an experiment manager class, as well as the tracking URI for all the experiments. However, please be aware that we only support input of a dictionary in the following style for `exp_manager`. For more information about `exp_manager`, users can refer to `Recorder: Experiment Management <../component/recorder.html>`_.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, if you want to set your tracking_uri to a <specific folder>, you can initialize qlib below
|
||||
|
||||
@@ -25,8 +25,11 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
|
||||
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
BIN
examples/benchmarks/TabNet/pretrain/best.model
Normal file
Binary file not shown.
4
examples/benchmarks/TabNet/requirements.txt
Normal file
4
examples/benchmarks/TabNet/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -0,0 +1,74 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -98,6 +98,7 @@ if __name__ == "__main__":
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"return_order": True,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -105,6 +106,11 @@ if __name__ == "__main__":
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# NOTE: This line is optional
|
||||
# It demonstrates that the dataset can be used standalone.
|
||||
example_df = dataset.prepare("train")
|
||||
print(example_df.head())
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
|
||||
@@ -45,9 +45,10 @@ def init(default_conf="client", **kwargs):
|
||||
C.set_region(kwargs.get("region", C["region"] if "region" in C else REG_CN))
|
||||
|
||||
for k, v in kwargs.items():
|
||||
C[k] = v
|
||||
if k not in C:
|
||||
LOG.warning("Unrecognized config %s" % k)
|
||||
else:
|
||||
C[k] = v
|
||||
|
||||
C.resolve_path()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import multiprocessing
|
||||
|
||||
class Config:
|
||||
def __init__(self, default_conf):
|
||||
self.__dict__["_default_config"] = default_conf # avoiding conflictions with __getattr__
|
||||
self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflictions with __getattr__
|
||||
self.reset()
|
||||
|
||||
def __getitem__(self, key):
|
||||
|
||||
@@ -6,3 +6,319 @@ from .account import Account
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func, get_date_range
|
||||
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ...utils import init_instance_by_config
|
||||
from ...log import get_module_logger
|
||||
from ...config import C
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy=None,
|
||||
topk=50,
|
||||
margin=0.5,
|
||||
n_drop=5,
|
||||
risk_degree=0.95,
|
||||
str_type="dropout",
|
||||
adjust_dates=None,
|
||||
):
|
||||
"""get_strategy
|
||||
|
||||
There will be 3 ways to return a stratgy. Please follow the code.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Strategy
|
||||
an initialized strategy object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a strategy.
|
||||
if strategy is None:
|
||||
# 1) create strategy with param `strategy`
|
||||
str_cls_dict = {
|
||||
"amount": "TopkAmountStrategy",
|
||||
"weight": "TopkWeightStrategy",
|
||||
"dropout": "TopkDropoutStrategy",
|
||||
}
|
||||
logger.info("Create new strategy ")
|
||||
from .. import strategy as strategy_pool
|
||||
|
||||
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
|
||||
strategy = str_cls(
|
||||
topk=topk,
|
||||
buffer_margin=margin,
|
||||
n_drop=n_drop,
|
||||
risk_degree=risk_degree,
|
||||
adjust_dates=adjust_dates,
|
||||
)
|
||||
elif isinstance(strategy, (dict, str)):
|
||||
# 2) create strategy with init_instance_by_config
|
||||
logger.info("Create new strategy ")
|
||||
strategy = init_instance_by_config(strategy)
|
||||
|
||||
from ..strategy.strategy import BaseStrategy
|
||||
|
||||
# else: nothing happens. 3) Use the strategy directly
|
||||
if not isinstance(strategy, BaseStrategy):
|
||||
raise TypeError("Strategy not supported")
|
||||
return strategy
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
extract_codes=False,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Exchange
|
||||
an initialized Exchange object
|
||||
"""
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values("instrument").unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
trade_dates=dates,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
min_cost=min_cost,
|
||||
trade_unit=trade_unit,
|
||||
)
|
||||
return exchange
|
||||
|
||||
|
||||
def get_executor(
|
||||
executor=None,
|
||||
trade_exchange=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""get_executor
|
||||
|
||||
There will be 3 ways to return a executor. Please follow the code.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
executor : BaseExecutor
|
||||
executor used in backtest.
|
||||
trade_exchange : Exchange
|
||||
exchange used in executor
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: BaseExecutor
|
||||
an initialized BaseExecutor object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a executor.
|
||||
if executor is None:
|
||||
# 1) create executor with param `executor`
|
||||
logger.info("Create new executor ")
|
||||
from ..online.executor import SimulatorExecutor
|
||||
|
||||
executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
|
||||
elif isinstance(executor, (dict, str)):
|
||||
# 2) create executor with config
|
||||
logger.info("Create new executor ")
|
||||
executor = init_instance_by_config(executor)
|
||||
|
||||
from ..online.executor import BaseExecutor
|
||||
|
||||
# 3) Use the executor directly
|
||||
if not isinstance(executor, BaseExecutor):
|
||||
raise TypeError("Executor not supported")
|
||||
return executor
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
Parameters
|
||||
----------
|
||||
|
||||
- **backtest workflow related or commmon arguments**
|
||||
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column.
|
||||
account : float
|
||||
init account value.
|
||||
shift : int
|
||||
whether to shift prediction by one day.
|
||||
benchmark : str
|
||||
benchmark code, default is SH000905 CSI 500.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
return_order : bool
|
||||
whether to return order list
|
||||
|
||||
- **strategy related arguments**
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
- **exchange related arguments**
|
||||
|
||||
exchange: Exchange()
|
||||
pass the exchange for speeding up.
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. The default value is 0.002(0.2%).
|
||||
close_cost : float
|
||||
close transaction cost. The default value is 0.002(0.2%).
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
# check strategy:
|
||||
spec = inspect.getfullargspec(get_strategy)
|
||||
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
strategy = get_strategy(**str_args)
|
||||
|
||||
# init exchange:
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
# init executor:
|
||||
executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
|
||||
|
||||
# run backtest
|
||||
report_dict = backtest_func(
|
||||
pred=pred,
|
||||
strategy=strategy,
|
||||
executor=executor,
|
||||
trade_exchange=trade_exchange,
|
||||
shift=shift,
|
||||
verbose=verbose,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
return_order=return_order,
|
||||
)
|
||||
# for compatibility of the old API. return the dict positions
|
||||
|
||||
positions = report_dict.get("positions")
|
||||
report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
|
||||
return report_dict
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from ...utils import get_date_by_shift, get_date_range
|
||||
from ..online.executor import SimulatorExecutor
|
||||
from ...data import D
|
||||
from .account import Account
|
||||
from ...config import C
|
||||
@@ -15,7 +14,7 @@ from ...data.dataset.utils import get_level_index
|
||||
LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark):
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
@@ -70,8 +69,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
bench = _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean()
|
||||
|
||||
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
|
||||
executor = SimulatorExecutor(trade_exchange, verbose=verbose)
|
||||
|
||||
if return_order:
|
||||
multi_order_list = []
|
||||
# trading apart
|
||||
for pred_date, trade_date in zip(predict_dates, trade_dates):
|
||||
# for loop predict date and trading date
|
||||
@@ -103,6 +102,8 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
)
|
||||
else:
|
||||
order_list = []
|
||||
if return_order:
|
||||
multi_order_list.append((trade_account, order_list, trade_date))
|
||||
# 4. Get result after executing order list
|
||||
# NOTE: The following operation will modify order.amount.
|
||||
# NOTE: If it is buy and the cash is insufficient, the tradable amount will be recalculated
|
||||
@@ -115,7 +116,11 @@ def backtest(pred, strategy, trade_exchange, shift, verbose, account, benchmark)
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
report_df["bench"] = bench
|
||||
positions = trade_account.get_positions()
|
||||
return report_df, positions
|
||||
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
if return_order:
|
||||
report_dict.update({"order_list": multi_order_list})
|
||||
return report_dict
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
|
||||
@@ -6,17 +6,16 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import inspect
|
||||
import warnings
|
||||
from ..log import get_module_logger
|
||||
from . import strategy as strategy_pool
|
||||
from .strategy.strategy import BaseStrategy
|
||||
from .backtest.exchange import Exchange
|
||||
from .backtest.backtest import backtest as backtest_func, get_date_range
|
||||
from .backtest import get_exchange, backtest as backtest_func
|
||||
from .backtest.backtest import get_date_range
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
@@ -46,144 +45,6 @@ def risk_analysis(r, N=252):
|
||||
return res
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy=None,
|
||||
topk=50,
|
||||
margin=0.5,
|
||||
n_drop=5,
|
||||
risk_degree=0.95,
|
||||
str_type="amount",
|
||||
adjust_dates=None,
|
||||
):
|
||||
"""get_strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Strategy
|
||||
an initialized strategy object
|
||||
"""
|
||||
if strategy is None:
|
||||
str_cls_dict = {
|
||||
"amount": "TopkAmountStrategy",
|
||||
"weight": "TopkWeightStrategy",
|
||||
"dropout": "TopkDropoutStrategy",
|
||||
}
|
||||
logger.info("Create new streategy ")
|
||||
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
|
||||
strategy = str_cls(
|
||||
topk=topk,
|
||||
buffer_margin=margin,
|
||||
n_drop=n_drop,
|
||||
risk_degree=risk_degree,
|
||||
adjust_dates=adjust_dates,
|
||||
)
|
||||
if not isinstance(strategy, BaseStrategy):
|
||||
raise TypeError("Strategy not supported")
|
||||
return strategy
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
extract_codes=False,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost.
|
||||
close_cost : float
|
||||
close transaction cost.
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Exchange
|
||||
an initialized Exchange object
|
||||
"""
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
limit_threshold = C.limit_threshold
|
||||
if deal_price is None:
|
||||
deal_price = C.deal_price
|
||||
if exchange is None:
|
||||
logger.info("Create new exchange")
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values("instrument").unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
trade_dates=dates,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
min_cost=min_cost,
|
||||
trade_unit=trade_unit,
|
||||
)
|
||||
return exchange
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
@@ -249,30 +110,22 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
# check strategy:
|
||||
spec = inspect.getfullargspec(get_strategy)
|
||||
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
strategy = get_strategy(**str_args)
|
||||
|
||||
# init exchange:
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
# run backtest
|
||||
report_df, positions = backtest_func(
|
||||
pred=pred,
|
||||
strategy=strategy,
|
||||
trade_exchange=trade_exchange,
|
||||
shift=shift,
|
||||
verbose=verbose,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
warnings.warn(
|
||||
"this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning
|
||||
)
|
||||
# for compatibility of the old API. return the dict positions
|
||||
positions = {k: p.position for k, p in positions.items()}
|
||||
return report_df, positions
|
||||
report_dict = backtest_func(
|
||||
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
|
||||
)
|
||||
return report_dict.get("report_df"), report_dict.get("positions")
|
||||
|
||||
|
||||
def long_short_backtest(
|
||||
|
||||
@@ -204,8 +204,8 @@ class ALSTM(Model):
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -260,7 +260,7 @@ class ALSTM(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.ALSTM_model.eval()
|
||||
|
||||
@@ -249,8 +249,8 @@ class GATs(Model):
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -332,7 +332,7 @@ class GATs(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
sampler_test = DailyBatchSampler(dl_test)
|
||||
test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)
|
||||
|
||||
@@ -204,8 +204,8 @@ class GRU(Model):
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -260,7 +260,7 @@ class GRU(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.GRU_model.eval()
|
||||
|
||||
@@ -204,8 +204,8 @@ class LSTM(Model):
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -260,7 +260,7 @@ class LSTM(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.LSTM_model.eval()
|
||||
|
||||
@@ -259,7 +259,7 @@ class DNNModelPytorch(Model):
|
||||
loss = torch.mul(sqr_loss, w).mean()
|
||||
return loss
|
||||
elif loss_type == "binary":
|
||||
loss = nn.BCELoss()
|
||||
loss = nn.BCELoss(weight=w)
|
||||
return loss(pred, target)
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
642
qlib/contrib/model/pytorch_tabnet.py
Normal file
642
qlib/contrib/model/pytorch_tabnet.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TabnetModel(Model):
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=158,
|
||||
out_dim=64,
|
||||
final_out_dim=1,
|
||||
batch_size=4096,
|
||||
n_d=64,
|
||||
n_a=64,
|
||||
n_shared=2,
|
||||
n_ind=2,
|
||||
n_steps=5,
|
||||
n_epochs=100,
|
||||
pretrain_n_epochs=50,
|
||||
relax=1.3,
|
||||
vbs=2048,
|
||||
seed=993,
|
||||
optimizer="adam",
|
||||
loss="mse",
|
||||
metric="",
|
||||
early_stop=20,
|
||||
GPU="1",
|
||||
pretrain_loss="custom",
|
||||
ps=0.3,
|
||||
lr=0.01,
|
||||
pretrain=True,
|
||||
pretrain_file="./pretrain/best.model",
|
||||
):
|
||||
"""
|
||||
TabNet model for Qlib
|
||||
|
||||
Args:
|
||||
ps: probability to generate the bernoulli mask
|
||||
"""
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.out_dim = out_dim
|
||||
self.final_out_dim = final_out_dim
|
||||
self.lr = lr
|
||||
self.batch_size = batch_size
|
||||
self.optimizer = optimizer.lower()
|
||||
self.pretrain_loss = pretrain_loss
|
||||
self.seed = seed
|
||||
self.ps = ps
|
||||
self.n_epochs = n_epochs
|
||||
self.logger = get_module_logger("TabNet")
|
||||
self.pretrain_n_epochs = pretrain_n_epochs
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() else "cpu"
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.pretrain = pretrain
|
||||
self.pretrain_file = pretrain_file
|
||||
self.logger.info(
|
||||
"TabNet:"
|
||||
"\nbatch_size : {}"
|
||||
"\nvirtual bs : {}"
|
||||
"\nGPU : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
|
||||
)
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.tabnet_model = TabNet(
|
||||
inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device
|
||||
).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.pretrain_optimizer = optim.Adam(
|
||||
list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr
|
||||
)
|
||||
self.train_optimizer = optim.Adam(self.tabnet_model.parameters(), lr=self.lr)
|
||||
|
||||
elif optimizer.lower() == "gd":
|
||||
self.pretrain_optimizer = optim.SGD(
|
||||
list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr
|
||||
)
|
||||
self.train_optimizer = optim.SGD(self.tabnet_model.parameters(), lr=self.lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
|
||||
# make a directory if pretrian director does not exist
|
||||
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
|
||||
self.logger.info("make folder to store model...")
|
||||
os.makedirs("pretrain")
|
||||
|
||||
[df_train, df_valid] = dataset.prepare(
|
||||
["pretrain", "pretrain_validation"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
df_valid.fillna(df_valid.mean(), inplace=True)
|
||||
|
||||
x_train = df_train["feature"]
|
||||
x_valid = df_valid["feature"]
|
||||
|
||||
# Early stop setup
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
|
||||
for epoch_idx in range(self.pretrain_n_epochs):
|
||||
self.logger.info("epoch: %s" % (epoch_idx))
|
||||
self.logger.info("pre-training...")
|
||||
self.pretrain_epoch(x_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss = self.pretrain_test_epoch(x_train)
|
||||
valid_loss = self.pretrain_test_epoch(x_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_loss, valid_loss))
|
||||
|
||||
if valid_loss < best_loss:
|
||||
self.logger.info("Save Model...")
|
||||
torch.save(self.tabnet_model.state_dict(), pretrain_file)
|
||||
best_loss = valid_loss
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
if self.pretrain:
|
||||
# there is a pretrained model, load the model
|
||||
self.logger.info("Pretrain...")
|
||||
self.pretrain_fn(dataset, self.pretrain_file)
|
||||
self.logger.info("Load Pretrain model")
|
||||
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file))
|
||||
|
||||
# adding one more linear layer to fit the final output dimension
|
||||
self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device)
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
|
||||
for epoch_idx in range(self.n_epochs):
|
||||
self.logger.info("epoch: %s" % (epoch_idx))
|
||||
self.logger.info("training...")
|
||||
self.train_epoch(x_train, y_train)
|
||||
self.logger.info("evaluating...")
|
||||
train_loss, train_score = self.test_epoch(x_train, y_train)
|
||||
valid_loss, val_score = self.test_epoch(x_valid, y_valid)
|
||||
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score < best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = epoch_idx
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.tabnet_model.eval()
|
||||
x_values = torch.from_numpy(x_test.values)
|
||||
x_values[torch.isnan(x_values)] = 0
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = x_values[begin:end].float().to(self.device)
|
||||
priors = torch.ones(end - begin, self.d_feat).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.tabnet_model(x_batch, priors).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
# prepare training data
|
||||
x_values = torch.from_numpy(data_x.values)
|
||||
y_values = torch.from_numpy(np.squeeze(data_y.values))
|
||||
x_values[torch.isnan(x_values)] = 0
|
||||
y_values[torch.isnan(y_values)] = 0
|
||||
self.tabnet_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
x_train_values = torch.from_numpy(x_train.values)
|
||||
y_train_values = torch.from_numpy(np.squeeze(y_train.values))
|
||||
x_train_values[torch.isnan(x_train_values)] = 0
|
||||
y_train_values[torch.isnan(y_train_values)] = 0
|
||||
self.tabnet_model.train()
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = x_train_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_train_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
|
||||
self.train_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.tabnet_model.parameters(), 3.0)
|
||||
self.train_optimizer.step()
|
||||
|
||||
def pretrain_epoch(self, x_train):
|
||||
train_set = torch.from_numpy(x_train.values)
|
||||
train_set[torch.isnan(train_set)] = 0
|
||||
indices = np.arange(len(train_set))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
self.tabnet_model.train()
|
||||
self.tabnet_decoder.train()
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps))
|
||||
x_train_values = train_set[indices[i : i + self.batch_size]] * (1 - S_mask)
|
||||
y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask)
|
||||
|
||||
S_mask = S_mask.to(self.device)
|
||||
feature = x_train_values.float().to(self.device)
|
||||
label = y_train_values.float().to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
|
||||
self.pretrain_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.pretrain_optimizer.step()
|
||||
|
||||
def pretrain_test_epoch(self, x_train):
|
||||
train_set = torch.from_numpy(x_train.values)
|
||||
train_set[torch.isnan(train_set)] = 0
|
||||
indices = np.arange(len(train_set))
|
||||
|
||||
self.tabnet_model.eval()
|
||||
self.tabnet_decoder.eval()
|
||||
|
||||
losses = []
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps))
|
||||
x_train_values = train_set[indices[i : i + self.batch_size]] * (1 - S_mask)
|
||||
y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask)
|
||||
|
||||
feature = x_train_values.float().to(self.device)
|
||||
label = y_train_values.float().to(self.device)
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
def pretrain_loss_fn(self, f_hat, f, S):
|
||||
"""
|
||||
Pretrain loss function defined in the original paper, read "Tabular self-supervised learning" in https://arxiv.org/pdf/1908.07442.pdf
|
||||
"""
|
||||
down_mean = torch.mean(f, dim=0)
|
||||
down = torch.sqrt(torch.sum(torch.square(f - down_mean), dim=0))
|
||||
up = (f_hat - f) * S
|
||||
return torch.sum(torch.square(up / down))
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
mask = ~torch.isnan(label)
|
||||
if self.loss == "mse":
|
||||
return self.mse(pred[mask], label[mask])
|
||||
raise ValueError("unknown loss `%s`" % self.loss)
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
|
||||
class FinetuneModel(nn.Module):
|
||||
"""
|
||||
FinuetuneModel for adding a layer by the end
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, trained_model):
|
||||
super().__init__()
|
||||
self.model = trained_model
|
||||
self.fc = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x, priors):
|
||||
return self.fc(self.model(x, priors)[0]).squeeze() # take the vec out
|
||||
|
||||
|
||||
class DecoderStep(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
super().__init__()
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device)
|
||||
self.fc = nn.Linear(out_dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fea_tran(x)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class TabNet_Decoder(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device):
|
||||
"""
|
||||
TabNet decoder that is used in pre-training
|
||||
"""
|
||||
self.out_dim = out_dim
|
||||
|
||||
super().__init__()
|
||||
if n_shared > 0:
|
||||
self.shared = nn.ModuleList()
|
||||
self.shared.append(nn.Linear(inp_dim, 2 * out_dim))
|
||||
for x in range(n_shared - 1):
|
||||
self.shared.append(nn.Linear(out_dim, 2 * out_dim)) # preset the linear function we will use
|
||||
else:
|
||||
self.shared = None
|
||||
self.n_steps = n_steps
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps):
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device))
|
||||
|
||||
def forward(self, x):
|
||||
out = torch.zeros(x.size(0), self.out_dim).to(x.device)
|
||||
for step in self.steps:
|
||||
out += step(x)
|
||||
return out
|
||||
|
||||
|
||||
class TabNet(nn.Module):
|
||||
def __init__(
|
||||
self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu"
|
||||
):
|
||||
"""
|
||||
TabNet AKA the original encoder
|
||||
|
||||
Args:
|
||||
n_d: dimension of the features used to calculate the final results
|
||||
n_a: dimension of the features input to the attention transformer of the next step
|
||||
n_shared: numbr of shared steps in feature transfomer(optional)
|
||||
n_ind: number of independent steps in feature transformer
|
||||
n_steps: number of steps of pass through tabbet
|
||||
relax coefficient:
|
||||
virtual batch size:
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# set the number of shared step in feature transformer
|
||||
if n_shared > 0:
|
||||
self.shared = nn.ModuleList()
|
||||
self.shared.append(nn.Linear(inp_dim, 2 * (n_d + n_a)))
|
||||
for x in range(n_shared - 1):
|
||||
self.shared.append(nn.Linear(n_d + n_a, 2 * (n_d + n_a))) # preset the linear function we will use
|
||||
else:
|
||||
self.shared = None
|
||||
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device)
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps - 1):
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device))
|
||||
self.fc = nn.Linear(n_d, out_dim)
|
||||
self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01)
|
||||
self.n_d = n_d
|
||||
|
||||
def forward(self, x, priors):
|
||||
assert not torch.isnan(x).any()
|
||||
x = self.bn(x)
|
||||
x_a = self.first_step(x)[:, self.n_d :]
|
||||
sparse_loss = torch.zeros(1).to(x.device)
|
||||
out = torch.zeros(x.size(0), self.n_d).to(x.device)
|
||||
for step in self.steps:
|
||||
x_te, l = step(x, x_a, priors)
|
||||
out += F.relu(x_te[:, : self.n_d]) # split the feautre from feat_transformer
|
||||
x_a = x_te[:, self.n_d :]
|
||||
sparse_loss += l
|
||||
return self.fc(out), sparse_loss
|
||||
|
||||
|
||||
class GBN(nn.Module):
|
||||
"""
|
||||
Ghost Batch Normalization
|
||||
an efficient way of doing batch normalization
|
||||
|
||||
Args:
|
||||
vbs: virtual batch size
|
||||
"""
|
||||
|
||||
def __init__(self, inp, vbs=1024, momentum=0.01):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm1d(inp, momentum=momentum)
|
||||
self.vbs = vbs
|
||||
|
||||
def forward(self, x):
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
"""
|
||||
GLU block that extracts only the most essential information
|
||||
|
||||
Args:
|
||||
vbs: virtual batch size
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, out_dim, fc=None, vbs=1024):
|
||||
super().__init__()
|
||||
if fc:
|
||||
self.fc = fc
|
||||
else:
|
||||
self.fc = nn.Linear(inp_dim, out_dim * 2)
|
||||
self.bn = GBN(out_dim * 2, vbs=vbs)
|
||||
self.od = out_dim
|
||||
|
||||
def forward(self, x):
|
||||
x = self.bn(self.fc(x))
|
||||
return torch.mul(x[:, : self.od], torch.sigmoid(x[:, self.od :]))
|
||||
|
||||
|
||||
class AttentionTransformer(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
relax: relax coefficient. The greater it is, we can
|
||||
use the same features more. When it is set to 1
|
||||
we can use every feature only once
|
||||
"""
|
||||
|
||||
def __init__(self, d_a, inp_dim, relax, vbs=1024):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(d_a, inp_dim)
|
||||
self.bn = GBN(inp_dim, vbs=vbs)
|
||||
self.r = relax
|
||||
|
||||
# a:feature from previous decision step
|
||||
def forward(self, a, priors):
|
||||
a = self.bn(self.fc(a))
|
||||
mask = SparsemaxFunction.apply(a * priors)
|
||||
priors = priors * (self.r - mask) # updating the prior
|
||||
return mask
|
||||
|
||||
|
||||
class FeatureTransformer(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
super().__init__()
|
||||
first = True
|
||||
self.shared = nn.ModuleList()
|
||||
if shared:
|
||||
self.shared.append(GLU(inp_dim, out_dim, shared[0], vbs=vbs))
|
||||
first = False
|
||||
for fc in shared[1:]:
|
||||
self.shared.append(GLU(out_dim, out_dim, fc, vbs=vbs))
|
||||
else:
|
||||
self.shared = None
|
||||
self.independ = nn.ModuleList()
|
||||
if first:
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = torch.sqrt(torch.tensor([0.5], device=device))
|
||||
|
||||
def forward(self, x):
|
||||
if self.shared:
|
||||
x = self.shared[0](x)
|
||||
for glu in self.shared[1:]:
|
||||
x = torch.add(x, glu(x))
|
||||
x = x * self.scale
|
||||
for glu in self.independ:
|
||||
x = torch.add(x, glu(x))
|
||||
x = x * self.scale
|
||||
return x
|
||||
|
||||
|
||||
class DecisionStep(nn.Module):
|
||||
"""
|
||||
One step for the TabNet
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device):
|
||||
super().__init__()
|
||||
self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device)
|
||||
|
||||
def forward(self, x, a, priors):
|
||||
mask = self.atten_tran(a, priors)
|
||||
sparse_loss = ((-1) * mask * torch.log(mask + 1e-10)).mean()
|
||||
x = self.fea_tran(x * mask)
|
||||
return x, sparse_loss
|
||||
|
||||
|
||||
def make_ix_like(input, dim=0):
|
||||
d = input.size(dim)
|
||||
rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
|
||||
view = [1] * input.dim()
|
||||
view[0] = -1
|
||||
return rho.view(view).transpose(0, dim)
|
||||
|
||||
|
||||
class SparsemaxFunction(Function):
|
||||
"""
|
||||
SparseMax function for replacing reLU
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, dim=-1):
|
||||
ctx.dim = dim
|
||||
max_val, _ = input.max(dim=dim, keepdim=True)
|
||||
input -= max_val # same numerical stability trick as for softmax
|
||||
tau, supp_size = SparsemaxFunction.threshold_and_support(input, dim=dim)
|
||||
output = torch.clamp(input - tau, min=0)
|
||||
ctx.save_for_backward(supp_size, output)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
supp_size, output = ctx.saved_tensors
|
||||
dim = ctx.dim
|
||||
grad_input = grad_output.clone()
|
||||
grad_input[output == 0] = 0
|
||||
|
||||
v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze()
|
||||
v_hat = v_hat.unsqueeze(dim)
|
||||
grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
|
||||
return grad_input, None
|
||||
|
||||
@staticmethod
|
||||
def threshold_and_support(input, dim=-1):
|
||||
input_srt, _ = torch.sort(input, descending=True, dim=dim)
|
||||
input_cumsum = input_srt.cumsum(dim) - 1
|
||||
rhos = make_ix_like(input, dim)
|
||||
support = rhos * input_srt > input_cumsum
|
||||
|
||||
support_size = support.sum(dim=dim).unsqueeze(dim)
|
||||
tau = input_cumsum.gather(dim, support_size - 1)
|
||||
tau /= support_size.to(input.dtype)
|
||||
return tau, support_size
|
||||
@@ -24,7 +24,12 @@ from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
init_instance_by_config,
|
||||
register_wrapper,
|
||||
get_module_by_module_path,
|
||||
)
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
@@ -1026,12 +1031,31 @@ class ClientProvider(BaseProvider):
|
||||
DatasetD.set_conn(self.client)
|
||||
|
||||
|
||||
Cal = Wrapper()
|
||||
Inst = Wrapper()
|
||||
FeatureD = Wrapper()
|
||||
ExpressionD = Wrapper()
|
||||
DatasetD = Wrapper()
|
||||
D = Wrapper()
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
CalendarProviderWrapper = Annotated[CalendarProvider, Wrapper]
|
||||
InstrumentProviderWrapper = Annotated[InstrumentProvider, Wrapper]
|
||||
FeatureProviderWrapper = Annotated[FeatureProvider, Wrapper]
|
||||
ExpressionProviderWrapper = Annotated[ExpressionProvider, Wrapper]
|
||||
DatasetProviderWrapper = Annotated[DatasetProvider, Wrapper]
|
||||
BaseProviderWrapper = Annotated[BaseProvider, Wrapper]
|
||||
else:
|
||||
CalendarProviderWrapper = CalendarProvider
|
||||
InstrumentProviderWrapper = InstrumentProvider
|
||||
FeatureProviderWrapper = FeatureProvider
|
||||
ExpressionProviderWrapper = ExpressionProvider
|
||||
DatasetProviderWrapper = DatasetProvider
|
||||
BaseProviderWrapper = BaseProvider
|
||||
|
||||
Cal: CalendarProviderWrapper = Wrapper()
|
||||
Inst: InstrumentProviderWrapper = Wrapper()
|
||||
FeatureD: FeatureProviderWrapper = Wrapper()
|
||||
ExpressionD: ExpressionProviderWrapper = Wrapper()
|
||||
DatasetD: DatasetProviderWrapper = Wrapper()
|
||||
D: BaseProviderWrapper = Wrapper()
|
||||
|
||||
|
||||
def register_all_wrappers():
|
||||
|
||||
@@ -49,20 +49,20 @@ class GetData:
|
||||
if resp.status_code != 200:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
chuck_size = 1024
|
||||
chunk_size = 1024
|
||||
logger.warning(
|
||||
f"The data for the example is collected from Yahoo Finance. Please be aware that the quality of the data might not be perfect. (You can refer to the original data source: https://finance.yahoo.com/lookup.)"
|
||||
)
|
||||
logger.info(f"{file_name} downloading......")
|
||||
with tqdm(total=int(resp.headers.get("Content-Length", 0))) as p_bar:
|
||||
with target_path.open("wb") as fp:
|
||||
for chuck in resp.iter_content(chunk_size=chuck_size):
|
||||
fp.write(chuck)
|
||||
p_bar.update(chuck_size)
|
||||
for chunk in resp.iter_content(chunk_size=chunk_size):
|
||||
fp.write(chunk)
|
||||
p_bar.update(chunk_size)
|
||||
|
||||
self._unzip(target_path, target_dir, delete_old)
|
||||
if self.delete_zip_file:
|
||||
target_path.unlike()
|
||||
target_path.unlink()
|
||||
|
||||
def check_dataset(self, file_name: str, dataset_version: str = None):
|
||||
url = self.merge_remote_url(file_name, dataset_version)
|
||||
|
||||
@@ -27,11 +27,6 @@ class Serializable:
|
||||
def dump_all(self):
|
||||
"""
|
||||
will the object dump all object
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
"""
|
||||
return getattr(self, "_dump_all", False)
|
||||
|
||||
@@ -39,11 +34,6 @@ class Serializable:
|
||||
def exclude(self):
|
||||
"""
|
||||
What attribute will be dumped
|
||||
|
||||
Parameters
|
||||
----------
|
||||
self : [TODO:type]
|
||||
[TODO:description]
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
|
||||
@@ -461,5 +461,14 @@ class QlibRecorder:
|
||||
self.get_exp().get_recorder().set_tags(**kwargs)
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
|
||||
else:
|
||||
QlibRecorderWrapper = QlibRecorder
|
||||
|
||||
# global record
|
||||
R = Wrapper()
|
||||
R: QlibRecorderWrapper = Wrapper()
|
||||
|
||||
@@ -44,7 +44,7 @@ def sys_config(config, config_path):
|
||||
# worflow handler function
|
||||
def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"):
|
||||
with open(config_path) as fp:
|
||||
config = yaml.load(fp, Loader=yaml.Loader)
|
||||
config = yaml.load(fp, Loader=yaml.SafeLoader)
|
||||
|
||||
# config the `sys` section
|
||||
sys_config(config, config_path)
|
||||
|
||||
@@ -65,13 +65,13 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end` method.")
|
||||
|
||||
def create_recorder(self, name=None):
|
||||
def create_recorder(self, recorder_name=None):
|
||||
"""
|
||||
Create a recorder for each experiment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
recorder_name : str
|
||||
the name of the recorder to be created.
|
||||
|
||||
Returns
|
||||
|
||||
@@ -5,10 +5,9 @@ import re
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from ..contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from ..contrib.evaluate import risk_analysis
|
||||
from ..contrib.backtest import backtest as normal_backtest
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
@@ -241,9 +240,14 @@ class PortAnaRecord(SignalRecord):
|
||||
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
order_normal = report_dict.get("order_list")
|
||||
if order_normal:
|
||||
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
|
||||
@@ -33,7 +33,6 @@ class InfoCollector:
|
||||
"scipy",
|
||||
"requests",
|
||||
"sacred",
|
||||
"pymongo",
|
||||
"python-socketio",
|
||||
"redis",
|
||||
"python-redis-lock",
|
||||
|
||||
Reference in New Issue
Block a user