1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 01:51:18 +08:00

Remove some deprecated code

This commit is contained in:
Young
2024-07-09 09:11:06 +00:00
parent 4405cb784f
commit 0f9312593d

View File

@@ -163,7 +163,6 @@ class GeneralPTNN(Model):
else:
self.scheduler = scheduler(optimizer=self.train_optimizer)
self.fitted = False
self.dnn_model.to(self.device)
@property
@@ -327,29 +326,10 @@ class GeneralPTNN(Model):
return preds
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
preds = self._nn_predict(x_test_pd)
return pd.Series(preds.reshape(-1), index=x_test_pd.index)
def save(self, filename, **kwargs):
with save_multiple_parts_file(filename) as model_dir:
model_path = os.path.join(model_dir, os.path.split(model_dir)[-1])
# Save model
torch.save(self.dnn_model.state_dict(), model_path)
def load(self, buffer, **kwargs):
with unpack_archive_with_buffer(buffer) as model_dir:
# Get model name
_model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[
0
]
_model_path = os.path.join(model_dir, _model_name)
# Load model
self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device))
self.fitted = True
class AverageMeter:
"""Computes and stores the average and current value"""