mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
Use callback in LGBM.train. (#974)
This commit is contained in:
@@ -68,17 +68,18 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
evals_result = {} # in case of unsafety of Python default values
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
ds, names = list(zip(*ds_l))
|
||||
early_stopping_callback = lgb.early_stopping(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
ds[0], # training dataset
|
||||
num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round,
|
||||
valid_sets=ds,
|
||||
valid_names=names,
|
||||
early_stopping_rounds=(
|
||||
self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds
|
||||
),
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
**kwargs,
|
||||
)
|
||||
for k in names:
|
||||
@@ -110,6 +111,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -117,5 +119,5 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
@@ -110,18 +110,21 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
evals_result=None,
|
||||
):
|
||||
if evals_result is None:
|
||||
evals_result = dict()
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
early_stopping_callback = lgb.early_stopping(early_stopping_rounds)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
evals_result_callback = lgb.record_evaluation(evals_result)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback],
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
@@ -147,6 +150,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
verbose_eval_callback = lgb.log_evaluation(period=verbose_eval)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
@@ -154,5 +158,5 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
callbacks=[verbose_eval_callback],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user