1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

Fix bugs in use_gpu

This commit is contained in:
D-X-Y
2021-03-11 19:10:32 -08:00
parent db59713d36
commit d38b8d6001
9 changed files with 10 additions and 13 deletions

View File

@@ -138,7 +138,7 @@ class ALSTM(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2

View File

@@ -143,7 +143,7 @@ class ALSTM(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2

View File

@@ -151,7 +151,7 @@ class GATs(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2

View File

@@ -172,7 +172,7 @@ class GATs(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2

View File

@@ -138,7 +138,7 @@ class GRU(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2

View File

@@ -143,7 +143,7 @@ class GRU(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2

View File

@@ -134,7 +134,7 @@ class LSTM(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2
@@ -291,10 +291,7 @@ class LSTM(Model):
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
if self.use_gpu:
pred = self.lstm_model(x_batch).detach().cpu().numpy()
else:
pred = self.lstm_model(x_batch).detach().numpy()
pred = self.lstm_model(x_batch).detach().cpu().numpy()
preds.append(pred)

View File

@@ -139,7 +139,7 @@ class LSTM(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def mse(self, pred, label):
loss = (pred - label) ** 2

View File

@@ -118,7 +118,7 @@ class TabnetModel(Model):
@property
def use_gpu(self):
self.device == torch.device("cpu")
self.device != torch.device("cpu")
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
get_or_create_path(pretrain_file)