diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index 135ce2e39..0ce8542e9 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -138,7 +138,7 @@ class ALSTM(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 7b4eddbbb..21867d951 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -143,7 +143,7 @@ class ALSTM(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 906af189b..0c66211b8 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -151,7 +151,7 @@ class GATs(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 4cce0e960..1c702da0f 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -172,7 +172,7 @@ class GATs(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index bf4e05cfd..e25e13212 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -138,7 +138,7 @@ class GRU(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 36b2e1492..6edddd755 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -143,7 +143,7 @@ class GRU(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index e8fa94abc..28d07665c 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -134,7 +134,7 @@ class LSTM(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 @@ -291,10 +291,7 @@ class LSTM(Model): x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device) with torch.no_grad(): - if self.use_gpu: - pred = self.lstm_model(x_batch).detach().cpu().numpy() - else: - pred = self.lstm_model(x_batch).detach().numpy() + pred = self.lstm_model(x_batch).detach().cpu().numpy() preds.append(pred) diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index dfac8d1f6..c6e99d19e 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -139,7 +139,7 @@ class LSTM(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def mse(self, pred, label): loss = (pred - label) ** 2 diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 950c2dffa..34d6d82f0 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -118,7 +118,7 @@ class TabnetModel(Model): @property def use_gpu(self): - self.device == torch.device("cpu") + self.device != torch.device("cpu") def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"): get_or_create_path(pretrain_file)