1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 11:30:57 +08:00

fix model when using single feature

This commit is contained in:
Dong Zhou
2020-11-24 21:03:52 +08:00
parent 73b280754d
commit e819879232
3 changed files with 4 additions and 4 deletions

View File

@@ -61,7 +61,7 @@ class CatBoostModel(Model):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
if __name__ == "__main__":

View File

@@ -16,7 +16,7 @@ class LGBModel(ModelFT):
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self.params = {"objective": loss}
self.params = {"objective": loss, 'verbosity': -1}
self.params.update(kwargs)
self.model = None
@@ -65,7 +65,7 @@ class LGBModel(ModelFT):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(np.squeeze(x_test.values)), index=x_test.index)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
"""

View File

@@ -61,4 +61,4 @@ class XGBModel(Model):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
return pd.Series(self.model.predict(xgb.DMatrix(np.squeeze(x_test.values))), index=x_test.index)
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)