mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Update models to enable save/load
This commit is contained in:
@@ -130,7 +130,7 @@ class ALSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -238,7 +238,7 @@ class ALSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -270,7 +270,7 @@ class ALSTM(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
|
||||
@@ -135,7 +135,7 @@ class ALSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -225,7 +225,7 @@ class ALSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -257,7 +257,7 @@ class ALSTM(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
|
||||
@@ -142,7 +142,7 @@ class GATs(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -275,7 +275,7 @@ class GATs(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -307,7 +307,7 @@ class GATs(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
|
||||
@@ -164,7 +164,7 @@ class GATs(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -297,7 +297,7 @@ class GATs(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -329,7 +329,7 @@ class GATs(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
|
||||
@@ -130,7 +130,7 @@ class GRU(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.gru_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -238,7 +238,7 @@ class GRU(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -270,7 +270,7 @@ class GRU(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
|
||||
@@ -135,7 +135,7 @@ class GRU(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.GRU_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -225,7 +225,7 @@ class GRU(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -257,7 +257,7 @@ class GRU(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
|
||||
@@ -130,7 +130,7 @@ class LSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.lstm_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -238,7 +238,7 @@ class LSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -270,7 +270,7 @@ class LSTM(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
|
||||
@@ -135,7 +135,7 @@ class LSTM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.LSTM_model.to(self.device)
|
||||
|
||||
def mse(self, pred, label):
|
||||
@@ -225,7 +225,7 @@ class LSTM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -257,7 +257,7 @@ class LSTM(Model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
|
||||
@@ -150,7 +150,7 @@ class DNNModelPytorch(Model):
|
||||
eps=1e-08,
|
||||
)
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
def fit(
|
||||
@@ -180,7 +180,7 @@ class DNNModelPytorch(Model):
|
||||
evals_result["valid"] = []
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
# return
|
||||
# prepare training data
|
||||
x_train_values = torch.from_numpy(x_train.values).float()
|
||||
@@ -265,7 +265,7 @@ class DNNModelPytorch(Model):
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test_pd = dataset.prepare("test", col_set="feature")
|
||||
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
|
||||
|
||||
@@ -302,7 +302,7 @@ class SFM(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
self._fitted = False
|
||||
self.fitted = False
|
||||
self.sfm_model.to(self.device)
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
@@ -386,7 +386,7 @@ class SFM(Model):
|
||||
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for step in range(self.n_epochs):
|
||||
self.logger.info("Epoch%d:", step)
|
||||
@@ -435,7 +435,7 @@ class SFM(Model):
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
|
||||
@@ -88,6 +88,7 @@ class TabnetModel(Model):
|
||||
"\nGPU : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
|
||||
)
|
||||
self.fitted = False
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
@@ -187,7 +188,7 @@ class TabnetModel(Model):
|
||||
evals_result["valid"] = []
|
||||
|
||||
self.logger.info("training...")
|
||||
self._fitted = True
|
||||
self.fitted = True
|
||||
|
||||
for epoch_idx in range(self.n_epochs):
|
||||
self.logger.info("epoch: %s" % (epoch_idx))
|
||||
@@ -212,7 +213,7 @@ class TabnetModel(Model):
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self._fitted:
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
|
||||
@@ -478,13 +478,13 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
data = pd.DataFrame(obj)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
data.index = _calendar[data.index.values.astype(np.int)]
|
||||
data.index = _calendar[data.index.values.astype(int)]
|
||||
data.index.names = ["datetime"]
|
||||
|
||||
if spans is None:
|
||||
return data
|
||||
else:
|
||||
mask = np.zeros(len(data), dtype=np.bool)
|
||||
mask = np.zeros(len(data), dtype=bool)
|
||||
for begin, end in spans:
|
||||
mask |= (data.index >= begin) & (data.index <= end)
|
||||
return data[mask]
|
||||
|
||||
Reference in New Issue
Block a user