From 0e2c2fcd7f7ddac2e30a0d53e6caf59989e903d5 Mon Sep 17 00:00:00 2001 From: lwwang1995 Date: Wed, 25 Nov 2020 10:44:48 +0800 Subject: [PATCH] Add Tabnet. --- examples/workflow_by_code_tabnet.py | 142 ++++++++++++++++++++++++++++ qlib/contrib/model/tabnet.py | 80 ++++++++++++++++ 2 files changed, 222 insertions(+) create mode 100644 examples/workflow_by_code_tabnet.py create mode 100644 qlib/contrib/model/tabnet.py diff --git a/examples/workflow_by_code_tabnet.py b/examples/workflow_by_code_tabnet.py new file mode 100644 index 000000000..d275a875c --- /dev/null +++ b/examples/workflow_by_code_tabnet.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from pathlib import Path + +import qlib +import pandas as pd +from qlib.config import REG_CN +from qlib.contrib.model.tabnet import TabNetModel +from qlib.contrib.data.handler import ALPHA360_Denoise +from qlib.contrib.strategy.strategy import TopkDropoutStrategy +from qlib.contrib.evaluate import ( + backtest as normal_backtest, + risk_analysis, +) +from qlib.utils import exists_qlib_data + +# from qlib.model.learner import train_model +from qlib.utils import init_instance_by_config + +import pickle + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) + from get_data import GetData + + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + MARKET = "csi300" + BENCHMARK = "SH000300" + + ################################### + # train model + ################################### + DATA_HANDLER_CONFIG = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": MARKET, + } + + TRAINER_CONFIG = { + "train_start_time": "2008-01-01", + "train_end_time": "2014-12-31", + "validate_start_time": "2015-01-01", + "validate_end_time": "2016-12-31", + "test_start_time": "2017-01-01", + "test_end_time": "2020-08-01", + } + + task = { + "model": { + "class": "TabNetModel", + "module_path": "qlib.contrib.model.tabnet", + "kwargs": { + "n_d": 8, + "n_a": 8, + "n_steps": 3, + "gamma": 1.3, + "n_independent": 2, + "n_shared": 2, + "seed": 0, + "momentum": 0.02, + "lambda_sparse": 1e-3, + "optimizer_params": {'lr':2e-3} + }, + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "ALPHA360_Denoise", + "module_path": "qlib.contrib.data.handler", + "kwargs": DATA_HANDLER_CONFIG, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + } + # You shoud record the data in specific sequence + # "record": ['SignalRecord', 'SigAnaRecord', 'PortAnaRecord'], + } + + # model = train_model(task) + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + model.fit(dataset) + + pred_score = model.predict(dataset) + + # save pred_score to file + pred_score_path = Path("~/tmp/qlib/pred_score.pkl").expanduser() + pred_score_path.parent.mkdir(exist_ok=True, parents=True) + pred_score.to_pickle(pred_score_path) + + ################################### + # backtest + ################################### + STRATEGY_CONFIG = { + "topk": 50, + "n_drop": 5, + } + BACKTEST_CONFIG = { + "verbose": False, + "limit_threshold": 0.095, + "account": 100000000, + "benchmark": BENCHMARK, + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + } + + # use default strategy + # custom Strategy, refer to: TODO: Strategy API url + strategy = TopkDropoutStrategy(**STRATEGY_CONFIG) + report_normal, positions_normal = normal_backtest(pred_score, strategy=strategy, **BACKTEST_CONFIG) + + ################################### + # analyze + # If need a more detailed analysis, refer to: examples/train_and_bakctest.ipynb + ################################### + analysis = dict() + analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_with_cost"] = risk_analysis( + report_normal["return"] - report_normal["bench"] - report_normal["cost"] + ) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + print(analysis_df) diff --git a/qlib/contrib/model/tabnet.py b/qlib/contrib/model/tabnet.py new file mode 100644 index 000000000..63a75d26f --- /dev/null +++ b/qlib/contrib/model/tabnet.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import pandas as pd +from pytorch_tabnet.tab_model import TabNetRegressor + +from ...model.base import Model +from ...data.dataset import DatasetH +from ...data.dataset.handler import DataHandlerLP + +class TabNetModel(Model): + """TabNetModel Model""" + + def __init__(self, n_d, n_a, + n_steps, + gamma, + n_independent, + n_shared, + seed, + momentum, + lambda_sparse, + optimizer_params, + **kwargs): + self.model = None + + self.n_d = n_d + self.n_a = n_a + self.n_steps = n_steps + self.gamma = gamma + self.n_independent = n_independent + self.n_shared = n_shared + self.seed = seed + self.momentum = momentum + self.lambda_sparse = lambda_sparse + self.optimizer_params = optimizer_params + + def fit( + self, + dataset: DatasetH, + n_d=8, + n_a=8, + n_steps=3, + gamma=1.3, + n_independent=2, + n_shared=2, + seed=0, + momentum=0.02, + lambda_sparse=1e-3, + optimizer_params={'lr':2e-3}, + **kwargs + ): + + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + x_train, y_train = df_train["feature"].values, df_train["label"].values*100 + x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values*100 + + self.model = TabNetRegressor( + n_d=self.n_d, + n_a=self.n_a, + n_steps=self.n_steps, + gamma=self.gamma, + n_independent=self.n_independent, + n_shared=self.n_shared, + seed=self.seed, + momentum=self.momentum, + lambda_sparse=self.lambda_sparse, + optimizer_params=self.optimizer_params, + **kwargs + ) + self.model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)]) + + def predict(self, dataset): + if self.model is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare("test", col_set="feature") + test_pred = self.model.predict(x_test.values) + return pd.Series(test_pred.reshape([-1]), index=x_test.index)