1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

Update pytorch_gats_ts.py

This commit is contained in:
Wendi Li
2021-01-17 12:09:48 +00:00
committed by you-n-g
parent 740c297618
commit fe60e40927

View File

@@ -249,8 +249,8 @@ class GATs(Model):
save_path=None,
):
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
@@ -332,7 +332,7 @@ class GATs(Model):
if not self._fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
sampler_test = DailyBatchSampler(dl_test)
test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)