mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 01:51:18 +08:00
chore: remove hard code input dimension of model pytorch_tcts (#843)
Co-authored-by: Jiabao Qu <qujiabao@logiocean.com>
This commit is contained in:
@@ -56,6 +56,7 @@ class TCTS(Model):
|
||||
loss="mse",
|
||||
fore_optimizer="adam",
|
||||
weight_optimizer="adam",
|
||||
input_dim=360,
|
||||
output_dim=5,
|
||||
fore_lr=5e-7,
|
||||
weight_lr=5e-7,
|
||||
@@ -83,6 +84,7 @@ class TCTS(Model):
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.fore_lr = fore_lr
|
||||
self.weight_lr = weight_lr
|
||||
@@ -139,7 +141,6 @@ class TCTS(Model):
|
||||
raise NotImplementedError("mode {} is not supported!".format(self.mode))
|
||||
|
||||
def train_epoch(self, x_train, y_train, x_valid, y_valid):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
@@ -297,7 +298,7 @@ class TCTS(Model):
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.weight_model = MLPModel(
|
||||
d_feat=360 + 3 * self.output_dim + 1,
|
||||
d_feat=self.input_dim + 3 * self.output_dim + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
|
||||
Reference in New Issue
Block a user