1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-06 04:20:57 +08:00

Add TabNet config

This commit is contained in:
Jactus
2020-11-25 11:16:01 +08:00
parent 0e2c2fcd7f
commit 991c6195bd
9 changed files with 114 additions and 34 deletions

View File

@@ -197,7 +197,8 @@ Here is a list of models built on `Qlib`.
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
- [TFT based on tensorflow-1.15.0](examples/benchmarks/TFT/tft.py)
- [TabNet based on pytorch](qlib/contrib/model/tabnet.py)
<!-- - [TFT based on tensorflow](examples/benchmarks/TFT/tft.py) -->
Your PR of new Quant models is highly welcomed.

View File

@@ -37,7 +37,7 @@ task:
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360_Denoise
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:

View File

@@ -44,7 +44,7 @@ task:
module_path: qlib.data.dataset
kwargs:
handler:
class: ALPHA360_Denoise
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:

View File

@@ -0,0 +1,5 @@
pandas==1.1.2
numpy==1.17.4
scikit_learn==0.23.2
torch==1.7.0
pytorch-tabnet==2.0.1

View File

@@ -0,0 +1,66 @@
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: TabNetModel
module_path: qlib.contrib.model.tabnet
kwargs:
n_d: 8
n_a: 8
n_steps: 3
gamma: 1.3
n_independent: 2
n_shared: 2
seed: 0
momentum: 0.02
lambda_sparse: 1e-3
optimizer_params: {lr: 2e-3}
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -71,7 +71,7 @@ if __name__ == "__main__":
"seed": 0,
"momentum": 0.02,
"lambda_sparse": 1e-3,
"optimizer_params": {'lr':2e-3}
"optimizer_params": {"lr": 2e-3},
},
},
"dataset": {

View File

@@ -9,19 +9,24 @@ from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
class TabNetModel(Model):
"""TabNetModel Model"""
def __init__(self, n_d, n_a,
n_steps,
gamma,
n_independent,
n_shared,
seed,
momentum,
lambda_sparse,
optimizer_params,
**kwargs):
def __init__(
self,
n_d,
n_a,
n_steps,
gamma,
n_independent,
n_shared,
seed,
momentum,
lambda_sparse,
optimizer_params,
**kwargs
):
self.model = None
self.n_d = n_d
@@ -47,28 +52,28 @@ class TabNetModel(Model):
seed=0,
momentum=0.02,
lambda_sparse=1e-3,
optimizer_params={'lr':2e-3},
optimizer_params={"lr": 2e-3},
**kwargs
):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"].values, df_train["label"].values*100
x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values*100
x_train, y_train = df_train["feature"].values, df_train["label"].values * 100
x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values * 100
self.model = TabNetRegressor(
n_d=self.n_d,
n_a=self.n_a,
n_steps=self.n_steps,
gamma=self.gamma,
n_independent=self.n_independent,
n_shared=self.n_shared,
seed=self.seed,
momentum=self.momentum,
lambda_sparse=self.lambda_sparse,
optimizer_params=self.optimizer_params,
**kwargs
n_d=self.n_d,
n_a=self.n_a,
n_steps=self.n_steps,
gamma=self.gamma,
n_independent=self.n_independent,
n_shared=self.n_shared,
seed=self.seed,
momentum=self.momentum,
lambda_sparse=self.lambda_sparse,
optimizer_params=self.optimizer_params,
**kwargs
)
self.model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)])

View File

@@ -25,7 +25,9 @@ class BaseStrategy:
return 0.95
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
"""Parameter
"""
Parameters:
-----------
score_series : pd.Seires
stock_id , score
current : Position()
@@ -44,8 +46,8 @@ class BaseStrategy:
def update(self, score_series, pred_date, trade_date):
"""User can use this method to update strategy state each trade date.
Parameter
---------
Parameters:
-----------
score_series : pd.Series
stock_id , score
pred_date : pd.Timestamp
@@ -97,7 +99,7 @@ class AdjustTimer:
Responsible for timing of position adjusting
This is designed as multiple inheritance mechanism due to
- the is_adjust may need access to the internel state of a strategyw
- the is_adjust may need access to the internel state of a strategy
- it can be reguard as a enhancement to the existing strategy
"""
@@ -139,7 +141,7 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
def generate_target_weight_position(self, score, current, trade_date):
"""
Parameters:
---------
-----------
score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
current : current position, use Position() class
trade_exchange : Exchange()
@@ -228,7 +230,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
Gnererate order list according to score_series at trade_date, will not change current.
Parameters:
----------
-----------
score_series : pd.Series
stock_id , score
current : Position()

View File

@@ -58,6 +58,7 @@ REQUIRED = [
"joblib>=0.17.0",
"fire>=0.3.1",
"ruamel.yaml>=0.16.12",
"pytorch-tabnet>=2.0.1",
]
# Numpy include