1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Update GRU model.

This commit is contained in:
lwwang1995
2020-11-12 13:11:14 +08:00
parent 52c0c4b7a8
commit d45aa86fb5

View File

@@ -293,22 +293,6 @@ class GRU(Model):
preds = self.gru_model(x_test).detach().numpy()
return pd.Series(preds, index=index)
def save(self, filename, **kwargs):
with save_multiple_parts_file(filename) as model_dir:
model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])
# Save model
torch.save(self.gru_model.state_dict(), model_path)
def load(self, buffer, **kwargs):
with unpack_archive_with_buffer(buffer) as model_dir:
# Get model name
_model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[
0
]
_model_path = os.path.join(model_dir, _model_name)
# Load model
self.gru_model.load_state_dict(torch.load(_model_path))
self._fitted = True
class AverageMeter(object):