mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 19:10:58 +08:00
Merge pull request #322 from Derek-Wds/bug
Fix pytorch ts model loader bug
This commit is contained in:
@@ -210,8 +210,12 @@ class ALSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
|
||||
@@ -258,8 +258,8 @@ class GATs(Model):
|
||||
sampler_train = DailyBatchSampler(dl_train)
|
||||
sampler_valid = DailyBatchSampler(dl_valid)
|
||||
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
|
||||
@@ -210,8 +210,12 @@ class GRU(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
|
||||
@@ -210,8 +210,12 @@ class LSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
|
||||
Reference in New Issue
Block a user