diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py index ee9d5a4f8..8a55d0385 100644 --- a/qlib/contrib/model/catboost_model.py +++ b/qlib/contrib/model/catboost_model.py @@ -54,11 +54,13 @@ class CatBoostModel(Model): self.model.fit( train_pool, eval_set = valid_pool, - use_best_model = True + use_best_model = True, + **kwargs ) - - evals_result["train"] = list(self.model.get_evals_result().values())[0] - evals_result["valid"] = self.model.get_test_eval() + + evals_result = self.model.get_evals_result() + evals_result["train"] = list(evals_result["learn"].values())[0] + evals_result["valid"] = list(evals_result["validation"].values())[0] def predict(self, dataset):