1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-05 12:00:58 +08:00

add map_location to torch.load to make it work when cuda is unavailable (#782)

This commit is contained in:
Chao Ning
2021-12-28 16:02:04 +00:00
committed by GitHub
parent 6bafd0a09b
commit 622303b83a
5 changed files with 7 additions and 7 deletions

View File

@@ -260,7 +260,7 @@ class GATs(Model):
if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}

View File

@@ -276,7 +276,7 @@ class GATs(Model):
if self.model_path is not None:
self.logger.info("Loading pretrained model...")
pretrained_model.load_state_dict(torch.load(self.model_path))
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}

View File

@@ -257,7 +257,7 @@ class DNNModelPytorch(Model):
self.scheduler.step(cur_loss_val)
# restore the optimal parameters after training
self.dnn_model.load_state_dict(torch.load(save_path))
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
if self.use_gpu:
torch.cuda.empty_cache()
@@ -296,7 +296,7 @@ class DNNModelPytorch(Model):
]
_model_path = os.path.join(model_dir, _model_name)
# Load model
self.dnn_model.load_state_dict(torch.load(_model_path))
self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device))
self.fitted = True

View File

@@ -160,7 +160,7 @@ class TabnetModel(Model):
self.logger.info("Pretrain...")
self.pretrain_fn(dataset, self.pretrain_file)
self.logger.info("Load Pretrain model")
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file))
self.tabnet_model.load_state_dict(torch.load(self.pretrain_file, map_location=self.device))
# adding one more linear layer to fit the final output dimension
self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device)

View File

@@ -350,9 +350,9 @@ class TCTS(Model):
break
print("best loss:", best_loss, "@", best_epoch)
best_param = torch.load(save_path + "_fore_model.bin")
best_param = torch.load(save_path + "_fore_model.bin", map_location=self.device)
self.fore_model.load_state_dict(best_param)
best_param = torch.load(save_path + "_weight_model.bin")
best_param = torch.load(save_path + "_weight_model.bin", map_location=self.device)
self.weight_model.load_state_dict(best_param)
self.fitted = True