mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
Update pytorch_gats_ts.py
This commit is contained in:
@@ -249,8 +249,8 @@ class GATs(Model):
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
@@ -332,7 +332,7 @@ class GATs(Model):
|
||||
if not self._fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
sampler_test = DailyBatchSampler(dl_test)
|
||||
test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs)
|
||||
|
||||
Reference in New Issue
Block a user