From 0f9312593d7cd2ddcbd4b4039cdb8563e74619a2 Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 9 Jul 2024 09:11:06 +0000 Subject: [PATCH] Remove some deprecated code --- qlib/contrib/model/pytorch_general_nn.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py index 0b361e8ac..c787ab2f5 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -163,7 +163,6 @@ class GeneralPTNN(Model): else: self.scheduler = scheduler(optimizer=self.train_optimizer) - self.fitted = False self.dnn_model.to(self.device) @property @@ -327,29 +326,10 @@ class GeneralPTNN(Model): return preds def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): - if not self.fitted: - raise ValueError("model is not fitted yet!") x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) preds = self._nn_predict(x_test_pd) return pd.Series(preds.reshape(-1), index=x_test_pd.index) - def save(self, filename, **kwargs): - with save_multiple_parts_file(filename) as model_dir: - model_path = os.path.join(model_dir, os.path.split(model_dir)[-1]) - # Save model - torch.save(self.dnn_model.state_dict(), model_path) - - def load(self, buffer, **kwargs): - with unpack_archive_with_buffer(buffer) as model_dir: - # Get model name - _model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[ - 0 - ] - _model_path = os.path.join(model_dir, _model_name) - # Load model - self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device)) - self.fitted = True - class AverageMeter: """Computes and stores the average and current value"""