mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-05 12:00:58 +08:00
add map_location to torch.load to make it work when cuda is unavailable (#782)
This commit is contained in:
@@ -260,7 +260,7 @@ class GATs(Model):
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
|
||||
@@ -276,7 +276,7 @@ class GATs(Model):
|
||||
|
||||
if self.model_path is not None:
|
||||
self.logger.info("Loading pretrained model...")
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
|
||||
@@ -257,7 +257,7 @@ class DNNModelPytorch(Model):
|
||||
self.scheduler.step(cur_loss_val)
|
||||
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path))
|
||||
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -296,7 +296,7 @@ class DNNModelPytorch(Model):
|
||||
]
|
||||
_model_path = os.path.join(model_dir, _model_name)
|
||||
# Load model
|
||||
self.dnn_model.load_state_dict(torch.load(_model_path))
|
||||
self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device))
|
||||
self.fitted = True
|
||||
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ class TabnetModel(Model):
|
||||
self.logger.info("Pretrain...")
|
||||
self.pretrain_fn(dataset, self.pretrain_file)
|
||||
self.logger.info("Load Pretrain model")
|
||||
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file))
|
||||
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file, map_location=self.device))
|
||||
|
||||
# adding one more linear layer to fit the final output dimension
|
||||
self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device)
|
||||
|
||||
@@ -350,9 +350,9 @@ class TCTS(Model):
|
||||
break
|
||||
|
||||
print("best loss:", best_loss, "@", best_epoch)
|
||||
best_param = torch.load(save_path + "_fore_model.bin")
|
||||
best_param = torch.load(save_path + "_fore_model.bin", map_location=self.device)
|
||||
self.fore_model.load_state_dict(best_param)
|
||||
best_param = torch.load(save_path + "_weight_model.bin")
|
||||
best_param = torch.load(save_path + "_weight_model.bin", map_location=self.device)
|
||||
self.weight_model.load_state_dict(best_param)
|
||||
self.fitted = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user