mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
Add TabNet config
This commit is contained in:
@@ -197,7 +197,8 @@ Here is a list of models built on `Qlib`.
|
||||
- [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py)
|
||||
- [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py)
|
||||
- [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py)
|
||||
- [TFT based on tensorflow-1.15.0](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch](qlib/contrib/model/tabnet.py)
|
||||
<!-- - [TFT based on tensorflow](examples/benchmarks/TFT/tft.py) -->
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ task:
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360_Denoise
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
|
||||
@@ -44,7 +44,7 @@ task:
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: ALPHA360_Denoise
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
|
||||
5
examples/benchmarks/TabNet/requirements.txt
Normal file
5
examples/benchmarks/TabNet/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
pytorch-tabnet==2.0.1
|
||||
66
examples/benchmarks/TabNet/workflow_config_tabnet.yaml
Normal file
66
examples/benchmarks/TabNet/workflow_config_tabnet.yaml
Normal file
@@ -0,0 +1,66 @@
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabNetModel
|
||||
module_path: qlib.contrib.model.tabnet
|
||||
kwargs:
|
||||
n_d: 8
|
||||
n_a: 8
|
||||
n_steps: 3
|
||||
gamma: 1.3
|
||||
n_independent: 2
|
||||
n_shared: 2
|
||||
seed: 0
|
||||
momentum: 0.02
|
||||
lambda_sparse: 1e-3
|
||||
optimizer_params: {lr: 2e-3}
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -71,7 +71,7 @@ if __name__ == "__main__":
|
||||
"seed": 0,
|
||||
"momentum": 0.02,
|
||||
"lambda_sparse": 1e-3,
|
||||
"optimizer_params": {'lr':2e-3}
|
||||
"optimizer_params": {"lr": 2e-3},
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
|
||||
@@ -9,19 +9,24 @@ from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TabNetModel(Model):
|
||||
"""TabNetModel Model"""
|
||||
|
||||
def __init__(self, n_d, n_a,
|
||||
n_steps,
|
||||
gamma,
|
||||
n_independent,
|
||||
n_shared,
|
||||
seed,
|
||||
momentum,
|
||||
lambda_sparse,
|
||||
optimizer_params,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
n_d,
|
||||
n_a,
|
||||
n_steps,
|
||||
gamma,
|
||||
n_independent,
|
||||
n_shared,
|
||||
seed,
|
||||
momentum,
|
||||
lambda_sparse,
|
||||
optimizer_params,
|
||||
**kwargs
|
||||
):
|
||||
self.model = None
|
||||
|
||||
self.n_d = n_d
|
||||
@@ -47,28 +52,28 @@ class TabNetModel(Model):
|
||||
seed=0,
|
||||
momentum=0.02,
|
||||
lambda_sparse=1e-3,
|
||||
optimizer_params={'lr':2e-3},
|
||||
optimizer_params={"lr": 2e-3},
|
||||
**kwargs
|
||||
):
|
||||
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"].values, df_train["label"].values*100
|
||||
x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values*100
|
||||
x_train, y_train = df_train["feature"].values, df_train["label"].values * 100
|
||||
x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values * 100
|
||||
|
||||
self.model = TabNetRegressor(
|
||||
n_d=self.n_d,
|
||||
n_a=self.n_a,
|
||||
n_steps=self.n_steps,
|
||||
gamma=self.gamma,
|
||||
n_independent=self.n_independent,
|
||||
n_shared=self.n_shared,
|
||||
seed=self.seed,
|
||||
momentum=self.momentum,
|
||||
lambda_sparse=self.lambda_sparse,
|
||||
optimizer_params=self.optimizer_params,
|
||||
**kwargs
|
||||
n_d=self.n_d,
|
||||
n_a=self.n_a,
|
||||
n_steps=self.n_steps,
|
||||
gamma=self.gamma,
|
||||
n_independent=self.n_independent,
|
||||
n_shared=self.n_shared,
|
||||
seed=self.seed,
|
||||
momentum=self.momentum,
|
||||
lambda_sparse=self.lambda_sparse,
|
||||
optimizer_params=self.optimizer_params,
|
||||
**kwargs
|
||||
)
|
||||
self.model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)])
|
||||
|
||||
|
||||
@@ -25,7 +25,9 @@ class BaseStrategy:
|
||||
return 0.95
|
||||
|
||||
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
|
||||
"""Parameter
|
||||
"""
|
||||
Parameters:
|
||||
-----------
|
||||
score_series : pd.Seires
|
||||
stock_id , score
|
||||
current : Position()
|
||||
@@ -44,8 +46,8 @@ class BaseStrategy:
|
||||
|
||||
def update(self, score_series, pred_date, trade_date):
|
||||
"""User can use this method to update strategy state each trade date.
|
||||
Parameter
|
||||
---------
|
||||
Parameters:
|
||||
-----------
|
||||
score_series : pd.Series
|
||||
stock_id , score
|
||||
pred_date : pd.Timestamp
|
||||
@@ -97,7 +99,7 @@ class AdjustTimer:
|
||||
Responsible for timing of position adjusting
|
||||
|
||||
This is designed as multiple inheritance mechanism due to
|
||||
- the is_adjust may need access to the internel state of a strategyw
|
||||
- the is_adjust may need access to the internel state of a strategy
|
||||
- it can be reguard as a enhancement to the existing strategy
|
||||
"""
|
||||
|
||||
@@ -139,7 +141,7 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
|
||||
def generate_target_weight_position(self, score, current, trade_date):
|
||||
"""
|
||||
Parameters:
|
||||
---------
|
||||
-----------
|
||||
score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column
|
||||
current : current position, use Position() class
|
||||
trade_exchange : Exchange()
|
||||
@@ -228,7 +230,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
Gnererate order list according to score_series at trade_date, will not change current.
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
-----------
|
||||
score_series : pd.Series
|
||||
stock_id , score
|
||||
current : Position()
|
||||
|
||||
Reference in New Issue
Block a user