From 622303b83a37c9199bf969a5e2e3de1072746bc4 Mon Sep 17 00:00:00 2001 From: Chao Ning Date: Tue, 28 Dec 2021 16:02:04 +0000 Subject: [PATCH] add map_location to torch.load to make it work when cuda is unavailable (#782) --- qlib/contrib/model/pytorch_gats.py | 2 +- qlib/contrib/model/pytorch_gats_ts.py | 2 +- qlib/contrib/model/pytorch_nn.py | 4 ++-- qlib/contrib/model/pytorch_tabnet.py | 2 +- qlib/contrib/model/pytorch_tcts.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 7f379c3b9..7c2c99432 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -260,7 +260,7 @@ class GATs(Model): if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path)) + pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) model_dict = self.GAT_model.state_dict() pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict} diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 401719275..53a7817e2 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -276,7 +276,7 @@ class GATs(Model): if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path)) + pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) model_dict = self.GAT_model.state_dict() pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict} diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 780dc4b91..7086cdb5a 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -257,7 +257,7 @@ class DNNModelPytorch(Model): self.scheduler.step(cur_loss_val) # restore the optimal parameters after training - self.dnn_model.load_state_dict(torch.load(save_path)) + self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device)) if self.use_gpu: torch.cuda.empty_cache() @@ -296,7 +296,7 @@ class DNNModelPytorch(Model): ] _model_path = os.path.join(model_dir, _model_name) # Load model - self.dnn_model.load_state_dict(torch.load(_model_path)) + self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device)) self.fitted = True diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 504048210..e0e2093e8 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -160,7 +160,7 @@ class TabnetModel(Model): self.logger.info("Pretrain...") self.pretrain_fn(dataset, self.pretrain_file) self.logger.info("Load Pretrain model") - self.tabnet_model.load_state_dict(torch.load(self.pretrain_file)) + self.tabnet_model.load_state_dict(torch.load(self.pretrain_file, map_location=self.device)) # adding one more linear layer to fit the final output dimension self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device) diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index 7cd59be9b..d813ae01f 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -350,9 +350,9 @@ class TCTS(Model): break print("best loss:", best_loss, "@", best_epoch) - best_param = torch.load(save_path + "_fore_model.bin") + best_param = torch.load(save_path + "_fore_model.bin", map_location=self.device) self.fore_model.load_state_dict(best_param) - best_param = torch.load(save_path + "_weight_model.bin") + best_param = torch.load(save_path + "_weight_model.bin", map_location=self.device) self.weight_model.load_state_dict(best_param) self.fitted = True