diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py index 00a0e4846..94f4397c5 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -17,6 +17,8 @@ import torch.nn as nn import torch.optim as optim from torch.utils.data import StackDataset +from qlib.data.dataset.weight import Reweighter + from .pytorch_utils import count_parameters from ...model.base import Model from ...data.dataset import DatasetH, TSDatasetH @@ -373,10 +375,6 @@ class GeneralPTNN(Model): def __init__( self, - d_feat=6, - hidden_size=64, - num_layers=2, - dropout=0.0, n_epochs=200, lr=0.001, metric="", @@ -387,17 +385,19 @@ class GeneralPTNN(Model): n_jobs=10, GPU=0, seed=None, - **kwargs + pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel", + pt_model_kwargs={ + "d_feat":6, + "hidden_size":64, + "num_layers":2, + "dropout":0., + }, ): # Set logger. - self.logger = get_module_logger("GRU") - self.logger.info("GRU pytorch version...") + self.logger = get_module_logger("GeneralPTNN") + self.logger.info("GeneralPTNN pytorch version...") # set hyper-parameters. - self.d_feat = d_feat - self.hidden_size = hidden_size - self.num_layers = num_layers - self.dropout = dropout self.n_epochs = n_epochs self.lr = lr self.metric = metric @@ -409,12 +409,11 @@ class GeneralPTNN(Model): self.n_jobs = n_jobs self.seed = seed + self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs + self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs}) + self.logger.info( - "GRU parameters setting:" - "\nd_feat : {}" - "\nhidden_size : {}" - "\nnum_layers : {}" - "\ndropout : {}" + "GeneralPTNN parameters setting:" "\nn_epochs : {}" "\nlr : {}" "\nmetric : {}" @@ -425,11 +424,9 @@ class GeneralPTNN(Model): "\ndevice : {}" "\nn_jobs : {}" "\nuse_GPU : {}" - "\nseed : {}".format( - d_feat, - hidden_size, - num_layers, - dropout, + "\nseed : {}" + "\npt_model_uri: {}" + "\npt_model_kwargs: {}".format( n_epochs, lr, metric, @@ -441,31 +438,28 @@ class GeneralPTNN(Model): n_jobs, self.use_gpu, seed, + pt_model_uri, + pt_model_kwargs, ) + ) if self.seed is not None: np.random.seed(self.seed) torch.manual_seed(self.seed) - self.GRU_model = GRUModel( - d_feat=self.d_feat, - hidden_size=self.hidden_size, - num_layers=self.num_layers, - dropout=self.dropout, - ) - self.logger.info("model:\n{:}".format(self.GRU_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model))) + self.logger.info("model:\n{:}".format(self.dnn_model)) + self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model))) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr) + self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.GRU_model.parameters(), lr=self.lr) + self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr) else: raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) self.fitted = False - self.GRU_model.to(self.device) + self.dnn_model.to(self.device) @property def use_gpu(self): @@ -495,22 +489,22 @@ class GeneralPTNN(Model): raise ValueError("unknown metric `%s`" % self.metric) def train_epoch(self, data_loader): - self.GRU_model.train() + self.dnn_model.train() for data, weight in data_loader: feature = data[:, :, 0:-1].to(self.device) label = data[:, -1, -1].to(self.device) - pred = self.GRU_model(feature.float()) + pred = self.dnn_model(feature.float()) loss = self.loss_fn(pred, label, weight.to(self.device)) self.train_optimizer.zero_grad() loss.backward() - torch.nn.utils.clip_grad_value_(self.GRU_model.parameters(), 3.0) + torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0) self.train_optimizer.step() def test_epoch(self, data_loader): - self.GRU_model.eval() + self.dnn_model.eval() scores = [] losses = [] @@ -521,7 +515,7 @@ class GeneralPTNN(Model): label = data[:, -1, -1].to(self.device) with torch.no_grad(): - pred = self.GRU_model(feature.float()) + pred = self.dnn_model(feature.float()) loss = self.loss_fn(pred, label, weight.to(self.device)) losses.append(loss.item()) @@ -597,7 +591,7 @@ class GeneralPTNN(Model): best_score = val_score stop_steps = 0 best_epoch = step - best_param = copy.deepcopy(self.GRU_model.state_dict()) + best_param = copy.deepcopy(self.dnn_model.state_dict()) else: stop_steps += 1 if stop_steps >= self.early_stop: @@ -605,7 +599,7 @@ class GeneralPTNN(Model): break self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) - self.GRU_model.load_state_dict(best_param) + self.dnn_model.load_state_dict(best_param) torch.save(best_param, save_path) if self.use_gpu: @@ -618,14 +612,14 @@ class GeneralPTNN(Model): dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) dl_test.config(fillna_type="ffill+bfill") test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) - self.GRU_model.eval() + self.dnn_model.eval() preds = [] for data in test_loader: feature = data[:, :, 0:-1].to(self.device) with torch.no_grad(): - pred = self.GRU_model(feature.float()).detach().cpu().numpy() + pred = self.dnn_model(feature.float()).detach().cpu().numpy() preds.append(pred) diff --git a/tests/model/test_general_nn.py b/tests/model/test_general_nn.py index faf75b724..2fa485fda 100644 --- a/tests/model/test_general_nn.py +++ b/tests/model/test_general_nn.py @@ -55,11 +55,24 @@ class TestNN(TestAutoData): # tabular dataset tbds = DatasetH(handler=data_handler, segments=segments) + + model_l = [ + GeneralPTNN( + n_epochs=2, + pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel", + pt_model_kwargs={ + "d_feat":3, + "hidden_size":8, + "num_layers":1, + "dropout":0., + }, + ), + ] - for ds in (tsds, tbds): - ptnn = GeneralPTNN() - ptnn.fit(ds) # It works - ptnn.predict(ds) # It works + for ds, model in zip((tsds, tbds), model_l): + model.fit(ds) # It works + model.predict(ds) # It works + break if __name__ == "__main__":