diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 7b0669dba..c3b8a2f06 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -249,8 +249,8 @@ class GATs(Model): save_path=None, ): - dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -332,7 +332,7 @@ class GATs(Model): if not self._fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) dl_test.config(fillna_type="ffill+bfill") sampler_test = DailyBatchSampler(dl_test) test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)