diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 725568de8..ad3945368 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -210,8 +210,12 @@ class ALSTM(Model): dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader - train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs) - valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs) + train_loader = DataLoader( + dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True + ) + valid_loader = DataLoader( + dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True + ) if save_path == None: save_path = create_save_path(save_path) diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 1e94f56e4..241debe61 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -258,8 +258,8 @@ class GATs(Model): sampler_train = DailyBatchSampler(dl_train) sampler_valid = DailyBatchSampler(dl_valid) - train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs) - valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs) + train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True) + valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True) if save_path == None: save_path = create_save_path(save_path) diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index bb6618b85..a0b240ef4 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -210,8 +210,12 @@ class GRU(Model): dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader - train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs) - valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs) + train_loader = DataLoader( + dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True + ) + valid_loader = DataLoader( + dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True + ) if save_path == None: save_path = create_save_path(save_path) diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index cf4f8fb9f..d0decc5fb 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -210,8 +210,12 @@ class LSTM(Model): dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader - train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs) - valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs) + train_loader = DataLoader( + dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True + ) + valid_loader = DataLoader( + dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True + ) if save_path == None: save_path = create_save_path(save_path)