1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Update pytorch_alstm_ts.py

This commit is contained in:
Wendi Li
2021-01-17 12:08:18 +00:00
committed by you-n-g
parent b4a088efe8
commit 740c297618

View File

@@ -204,8 +204,8 @@ class ALSTM(Model):
verbose=True,
save_path=None,
):
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
@@ -260,7 +260,7 @@ class ALSTM(Model):
if not self._fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.ALSTM_model.eval()