mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
Fix bugs in use_gpu
This commit is contained in:
@@ -138,7 +138,7 @@ class ALSTM(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
|
||||
@@ -143,7 +143,7 @@ class ALSTM(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
|
||||
@@ -151,7 +151,7 @@ class GATs(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
|
||||
@@ -172,7 +172,7 @@ class GATs(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
|
||||
@@ -138,7 +138,7 @@ class GRU(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
|
||||
@@ -143,7 +143,7 @@ class GRU(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
|
||||
@@ -134,7 +134,7 @@ class LSTM(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
@@ -291,10 +291,7 @@ class LSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.lstm_model(x_batch).detach().numpy()
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ class LSTM(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
|
||||
@@ -118,7 +118,7 @@ class TabnetModel(Model):
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
self.device == torch.device("cpu")
|
||||
self.device != torch.device("cpu")
|
||||
|
||||
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
|
||||
get_or_create_path(pretrain_file)
|
||||
|
||||
Reference in New Issue
Block a user