diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index d813ae01f..8cb56930d 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -56,6 +56,7 @@ class TCTS(Model): loss="mse", fore_optimizer="adam", weight_optimizer="adam", + input_dim=360, output_dim=5, fore_lr=5e-7, weight_lr=5e-7, @@ -83,6 +84,7 @@ class TCTS(Model): self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu") self.use_gpu = torch.cuda.is_available() self.seed = seed + self.input_dim = input_dim self.output_dim = output_dim self.fore_lr = fore_lr self.weight_lr = weight_lr @@ -139,7 +141,6 @@ class TCTS(Model): raise NotImplementedError("mode {} is not supported!".format(self.mode)) def train_epoch(self, x_train, y_train, x_valid, y_valid): - x_train_values = x_train.values y_train_values = np.squeeze(y_train.values) @@ -297,7 +298,7 @@ class TCTS(Model): dropout=self.dropout, ) self.weight_model = MLPModel( - d_feat=360 + 3 * self.output_dim + 1, + d_feat=self.input_dim + 3 * self.output_dim + 1, hidden_size=self.hidden_size, num_layers=self.num_layers, dropout=self.dropout,