diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_early_stop_Alpha158.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_early_stop_Alpha158.yaml new file mode 100644 index 000000000..b3c38870e --- /dev/null +++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_early_stop_Alpha158.yaml @@ -0,0 +1,95 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + - <MODEL> + - <DATASET> + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: DEnsembleModel + module_path: qlib.contrib.model.double_ensemble + kwargs: + base_model: "gbm" + loss: mse + num_models: 3 + enable_sr: True + enable_fs: True + alpha1: 1 + alpha2: 1 + bins_sr: 10 + bins_fs: 5 + decay: 0.5 + sample_ratios: + - 0.8 + - 0.7 + - 0.6 + - 0.5 + - 0.4 + sub_weights: + - 1 + - 1 + - 1 + epochs: 1000 + early_stopping_rounds: 50 + colsample_bytree: 0.8879 + learning_rate: 0.2 + subsample: 0.8789 + lambda_l1: 205.6999 + lambda_l2: 580.9768 + max_depth: 8 + num_leaves: 210 + num_threads: 20 + verbosity: -1 + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + 
model: <MODEL> + dataset: <DATASET> + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/qlib/contrib/model/double_ensemble.py b/qlib/contrib/model/double_ensemble.py index 50c3d22b4..f0b2188d0 100644 --- a/qlib/contrib/model/double_ensemble.py +++ b/qlib/contrib/model/double_ensemble.py @@ -30,6 +30,7 @@ class DEnsembleModel(Model, FeatureInt): sample_ratios=None, sub_weights=None, epochs=100, + early_stopping_rounds=None, **kwargs ): self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" @@ -59,6 +60,7 @@ class DEnsembleModel(Model, FeatureInt): self.params = {"objective": loss} self.params.update(kwargs) self.loss = loss + self.early_stopping_rounds = early_stopping_rounds def fit(self, dataset: DatasetH): df_train, df_valid = dataset.prepare( @@ -103,14 +105,19 @@ class DEnsembleModel(Model, FeatureInt): def train_submodel(self, df_train, df_valid, weights, features): dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features) evals_result = dict() + + callbacks = [lgb.log_evaluation(20), lgb.record_evaluation(evals_result)] + if self.early_stopping_rounds: + callbacks.append(lgb.early_stopping(self.early_stopping_rounds)) + self.logger.info("Training with early_stopping...") + model = lgb.train( self.params, dtrain, num_boost_round=self.epochs, valid_sets=[dtrain, dvalid], valid_names=["train", "valid"], - verbose_eval=20, - evals_result=evals_result, + callbacks=callbacks, ) evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0]