diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index 7b999d0a1..fb422f491 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -293,22 +293,6 @@ class GRU(Model): preds = self.gru_model(x_test).detach().numpy() return pd.Series(preds, index=index) - def save(self, filename, **kwargs): - with save_multiple_parts_file(filename) as model_dir: - model_path = os.path.join(model_dir, os.path.split(model_dir)[-1]) - # Save model - torch.save(self.gru_model.state_dict(), model_path) - - def load(self, buffer, **kwargs): - with unpack_archive_with_buffer(buffer) as model_dir: - # Get model name - _model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[ - 0 - ] - _model_path = os.path.join(model_dir, _model_name) - # Load model - self.gru_model.load_state_dict(torch.load(_model_path)) - self._fitted = True class AverageMeter(object):