1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Add early stopping to double ensemble model, add example (#1375)

* Add early stopping to double ensemble model, add example

* Fix lint error
This commit is contained in:
Di
2022-11-28 14:02:44 +08:00
committed by GitHub
parent c4ee9ff882
commit 7e5bab599a
2 changed files with 104 additions and 2 deletions

View File

@@ -0,0 +1,95 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: DEnsembleModel
module_path: qlib.contrib.model.double_ensemble
kwargs:
base_model: "gbm"
loss: mse
num_models: 3
enable_sr: True
enable_fs: True
alpha1: 1
alpha2: 1
bins_sr: 10
bins_fs: 5
decay: 0.5
sample_ratios:
- 0.8
- 0.7
- 0.6
- 0.5
- 0.4
sub_weights:
- 1
- 1
- 1
epochs: 1000
early_stopping_rounds: 50
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
verbosity: -1
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -30,6 +30,7 @@ class DEnsembleModel(Model, FeatureInt):
sample_ratios=None,
sub_weights=None,
epochs=100,
early_stopping_rounds=None,
**kwargs
):
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
@@ -59,6 +60,7 @@ class DEnsembleModel(Model, FeatureInt):
self.params = {"objective": loss}
self.params.update(kwargs)
self.loss = loss
self.early_stopping_rounds = early_stopping_rounds
def fit(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
@@ -103,14 +105,19 @@ class DEnsembleModel(Model, FeatureInt):
def train_submodel(self, df_train, df_valid, weights, features):
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
evals_result = dict()
callbacks = [lgb.log_evaluation(20), lgb.record_evaluation(evals_result)]
if self.early_stopping_rounds:
callbacks.append(lgb.early_stopping(self.early_stopping_rounds))
self.logger.info("Training with early_stopping...")
model = lgb.train(
self.params,
dtrain,
num_boost_round=self.epochs,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
verbose_eval=20,
evals_result=evals_result,
callbacks=callbacks,
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]