1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

Use callback in LGBM.train. (#974)

This commit is contained in:
Chia-hung Tai
2022-03-13 11:20:18 +08:00
committed by GitHub
parent 921c13cc90
commit 829ad9f5e9
2 changed files with 17 additions and 11 deletions

View File

@@ -68,17 +68,18 @@ class LGBModel(ModelFT, LightGBMFInt):
evals_result = {} # in case of unsafety of Python default values
ds_l = self._prepare_data(dataset, reweighter)
ds, names = list(zip(*ds_l))
early_stopping_callback = lgb.early_stopping(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
self.model = lgb.train(
self.params,
ds[0], # training dataset
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
valid_sets=ds,
valid_names=names,
early_stopping_rounds=(
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
),
verbose_eval=verbose_eval,
evals_result=evals_result,
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
**kwargs,
)
for k in names:
@@ -110,6 +111,7 @@ class LGBModel(ModelFT, LightGBMFInt):
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
if dtrain.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
self.model = lgb.train(
self.params,
dtrain,
@@ -117,5 +119,5 @@ class LGBModel(ModelFT, LightGBMFInt):
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
callbacks=[verbose_eval_callback],
)

View File

@@ -110,18 +110,21 @@ class HFLGBModel(ModelFT, LightGBMFInt):
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
evals_result=None,
):
if evals_result is None:
evals_result = dict()
dtrain, dvalid = self._prepare_data(dataset)
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
evals_result_callback = lgb.record_evaluation(evals_result)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
@@ -147,6 +150,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
self.model = lgb.train(
self.params,
dtrain,
@@ -154,5 +158,5 @@ class HFLGBModel(ModelFT, LightGBMFInt):
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
callbacks=[verbose_eval_callback],
)