From 9abc0b0d4f4de7bd65d0e6392ba75115089c2d24 Mon Sep 17 00:00:00 2001 From: Wendi Li Date: Sun, 17 Jan 2021 12:10:43 +0000 Subject: [PATCH] Update pytorch_gru_ts.py --- qlib/contrib/model/pytorch_gru_ts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 149c9f8d0..144d97031 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -204,8 +204,8 @@ class GRU(Model): verbose=True, save_path=None, ): - dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -260,7 +260,7 @@ class GRU(Model): if not self._fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) dl_test.config(fillna_type="ffill+bfill") test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) self.GRU_model.eval()