From fe60e409278a3350d4116a24c899f7a5bf1eae96 Mon Sep 17 00:00:00 2001 From: Wendi Li Date: Sun, 17 Jan 2021 12:09:48 +0000 Subject: [PATCH] Update pytorch_gats_ts.py --- qlib/contrib/model/pytorch_gats_ts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 7b0669dba..c3b8a2f06 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -249,8 +249,8 @@ class GATs(Model): save_path=None, ): - dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -332,7 +332,7 @@ class GATs(Model): if not self._fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) dl_test.config(fillna_type="ffill+bfill") sampler_test = DailyBatchSampler(dl_test) test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)