diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 149c9f8d0..144d97031 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -204,8 +204,8 @@ class GRU(Model): verbose=True, save_path=None, ): - dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -260,7 +260,7 @@ class GRU(Model): if not self._fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) dl_test.config(fillna_type="ffill+bfill") test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) self.GRU_model.eval()