1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

Update pytorch_gru_ts.py

This commit is contained in:
Wendi Li
2021-01-17 12:10:43 +00:00
committed by you-n-g
parent fe60e40927
commit 9abc0b0d4f

View File

@@ -204,8 +204,8 @@ class GRU(Model):
verbose=True,
save_path=None,
):
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
@@ -260,7 +260,7 @@ class GRU(Model):
if not self._fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.GRU_model.eval()