From 991c6195bd897a8834741fc011cbd25e6a59b71f Mon Sep 17 00:00:00 2001 From: Jactus Date: Wed, 25 Nov 2020 11:16:01 +0800 Subject: [PATCH] Add TabNet config --- README.md | 3 +- .../CatBoost/workflow_config_catboost.yaml | 2 +- .../benchmarks/DNN/workflow_config_dnn.yaml | 2 +- examples/benchmarks/TabNet/requirements.txt | 5 ++ .../TabNet/workflow_config_tabnet.yaml | 66 +++++++++++++++++++ examples/workflow_by_code_tabnet.py | 2 +- qlib/contrib/model/tabnet.py | 53 ++++++++------- qlib/contrib/strategy/strategy.py | 14 ++-- setup.py | 1 + 9 files changed, 114 insertions(+), 34 deletions(-) create mode 100644 examples/benchmarks/TabNet/requirements.txt create mode 100644 examples/benchmarks/TabNet/workflow_config_tabnet.yaml diff --git a/README.md b/README.md index b06afd975..7c7e58a1c 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,8 @@ Here is a list of models built on `Qlib`. - [GRU based on pytorch](qlib/contrib/model/pytorch_gru.py) - [LSTM based on pytorcn](qlib/contrib/model/pytorch_lstm.py) - [GATs based on pytorch](qlib/contrib/model/pytorch_gats.py) -- [TFT based on tensorflow-1.15.0](examples/benchmarks/TFT/tft.py) +- [TabNet based on pytorch](qlib/contrib/model/tabnet.py) + Your PR of new Quant models is highly welcomed. diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml index 80229e22b..8bf3bb72b 100644 --- a/examples/benchmarks/CatBoost/workflow_config_catboost.yaml +++ b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml @@ -37,7 +37,7 @@ task: module_path: qlib.data.dataset kwargs: handler: - class: ALPHA360_Denoise + class: Alpha158 module_path: qlib.contrib.data.handler kwargs: *data_handler_config segments: diff --git a/examples/benchmarks/DNN/workflow_config_dnn.yaml b/examples/benchmarks/DNN/workflow_config_dnn.yaml index 6dbd345dd..e853726ca 100644 --- a/examples/benchmarks/DNN/workflow_config_dnn.yaml +++ b/examples/benchmarks/DNN/workflow_config_dnn.yaml @@ -44,7 +44,7 @@ task: module_path: qlib.data.dataset kwargs: handler: - class: ALPHA360_Denoise + class: Alpha158 module_path: qlib.contrib.data.handler kwargs: *data_handler_config segments: diff --git a/examples/benchmarks/TabNet/requirements.txt b/examples/benchmarks/TabNet/requirements.txt new file mode 100644 index 000000000..244b74b19 --- /dev/null +++ b/examples/benchmarks/TabNet/requirements.txt @@ -0,0 +1,5 @@ +pandas==1.1.2 +numpy==1.17.4 +scikit_learn==0.23.2 +torch==1.7.0 +pytorch-tabnet==2.0.1 \ No newline at end of file diff --git a/examples/benchmarks/TabNet/workflow_config_tabnet.yaml b/examples/benchmarks/TabNet/workflow_config_tabnet.yaml new file mode 100644 index 000000000..0ee95f238 --- /dev/null +++ b/examples/benchmarks/TabNet/workflow_config_tabnet.yaml @@ -0,0 +1,66 @@ +provider_uri: "~/.qlib/qlib_data/cn_data" +region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy.strategy + kwargs: + topk: 50 + n_drop: 5 + backtest: + verbose: False + limit_threshold: 0.095 + account: 100000000 + benchmark: *benchmark + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: TabNetModel + module_path: qlib.contrib.model.tabnet + kwargs: + n_d: 8 + n_a: 8 + n_steps: 3 + gamma: 1.3 + n_independent: 2 + n_shared: 2 + seed: 0 + momentum: 0.02 + lambda_sparse: 1e-3 + optimizer_params: {lr: 2e-3} + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: {} + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config \ No newline at end of file diff --git a/examples/workflow_by_code_tabnet.py b/examples/workflow_by_code_tabnet.py index d275a875c..3778b9d59 100644 --- a/examples/workflow_by_code_tabnet.py +++ b/examples/workflow_by_code_tabnet.py @@ -71,7 +71,7 @@ if __name__ == "__main__": "seed": 0, "momentum": 0.02, "lambda_sparse": 1e-3, - "optimizer_params": {'lr':2e-3} + "optimizer_params": {"lr": 2e-3}, }, }, "dataset": { diff --git a/qlib/contrib/model/tabnet.py b/qlib/contrib/model/tabnet.py index 63a75d26f..bc13d1f62 100644 --- a/qlib/contrib/model/tabnet.py +++ b/qlib/contrib/model/tabnet.py @@ -9,19 +9,24 @@ from ...model.base import Model from ...data.dataset import DatasetH from ...data.dataset.handler import DataHandlerLP + class TabNetModel(Model): """TabNetModel Model""" - def __init__(self, n_d, n_a, - n_steps, - gamma, - n_independent, - n_shared, - seed, - momentum, - lambda_sparse, - optimizer_params, - **kwargs): + def __init__( + self, + n_d, + n_a, + n_steps, + gamma, + n_independent, + n_shared, + seed, + momentum, + lambda_sparse, + optimizer_params, + **kwargs + ): self.model = None self.n_d = n_d @@ -47,28 +52,28 @@ class TabNetModel(Model): seed=0, momentum=0.02, lambda_sparse=1e-3, - optimizer_params={'lr':2e-3}, + optimizer_params={"lr": 2e-3}, **kwargs ): df_train, df_valid = dataset.prepare( ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L ) - x_train, y_train = df_train["feature"].values, df_train["label"].values*100 - x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values*100 + x_train, y_train = df_train["feature"].values, df_train["label"].values * 100 + x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values * 100 self.model = TabNetRegressor( - n_d=self.n_d, - n_a=self.n_a, - n_steps=self.n_steps, - gamma=self.gamma, - n_independent=self.n_independent, - n_shared=self.n_shared, - seed=self.seed, - momentum=self.momentum, - lambda_sparse=self.lambda_sparse, - optimizer_params=self.optimizer_params, - **kwargs + n_d=self.n_d, + n_a=self.n_a, + n_steps=self.n_steps, + gamma=self.gamma, + n_independent=self.n_independent, + n_shared=self.n_shared, + seed=self.seed, + momentum=self.momentum, + lambda_sparse=self.lambda_sparse, + optimizer_params=self.optimizer_params, + **kwargs ) self.model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)]) diff --git a/qlib/contrib/strategy/strategy.py b/qlib/contrib/strategy/strategy.py index 084737445..6eac9bafe 100644 --- a/qlib/contrib/strategy/strategy.py +++ b/qlib/contrib/strategy/strategy.py @@ -25,7 +25,9 @@ class BaseStrategy: return 0.95 def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date): - """Parameter + """ + Parameters: + ----------- score_series : pd.Seires stock_id , score current : Position() @@ -44,8 +46,8 @@ class BaseStrategy: def update(self, score_series, pred_date, trade_date): """User can use this method to update strategy state each trade date. - Parameter - --------- + Parameters: + ----------- score_series : pd.Series stock_id , score pred_date : pd.Timestamp @@ -97,7 +99,7 @@ class AdjustTimer: Responsible for timing of position adjusting This is designed as multiple inheritance mechanism due to - - the is_adjust may need access to the internel state of a strategyw + - the is_adjust may need access to the internel state of a strategy - it can be reguard as a enhancement to the existing strategy """ @@ -139,7 +141,7 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer): def generate_target_weight_position(self, score, current, trade_date): """ Parameters: - --------- + ----------- score : pred score for this trade date, pd.Series, index is stock_id, contain 'score' column current : current position, use Position() class trade_exchange : Exchange() @@ -228,7 +230,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer): Gnererate order list according to score_series at trade_date, will not change current. Parameters: - ---------- + ----------- score_series : pd.Series stock_id , score current : Position() diff --git a/setup.py b/setup.py index 2c9cfea95..4fe410b9d 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ REQUIRED = [ "joblib>=0.17.0", "fire>=0.3.1", "ruamel.yaml>=0.16.12", + "pytorch-tabnet>=2.0.1", ] # Numpy include