diff --git a/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py index da3d14343..44a9284f7 100644 --- a/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py +++ b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py @@ -194,10 +194,10 @@ class Alpha158Formatter(GenericDataFormatter): """Returns fixed model parameters for experiments.""" fixed_params = { - "total_time_steps": 16 + 6, - "num_encoder_steps": 16, + "total_time_steps": 6 + 6, + "num_encoder_steps": 6, "num_epochs": 100, - "early_stopping_patience": 5, + "early_stopping_patience": 10, "multiprocessing_workers": 5, } @@ -207,11 +207,11 @@ class Alpha158Formatter(GenericDataFormatter): """Returns default optimised model parameters.""" model_params = { - "dropout_rate": 0.3, - "hidden_layer_size": 160, - "learning_rate": 0.001, - "minibatch_size": 64, - "max_gradient_norm": 0.01, + "dropout_rate": 0.4, + "hidden_layer_size": 16, + "learning_rate": 0.0001, + "minibatch_size": 128, + "max_gradient_norm": 0.0135, "num_heads": 1, "stack_size": 1, } diff --git a/examples/benchmarks/TFT/libs/tft_model.py b/examples/benchmarks/TFT/libs/tft_model.py index 658bae60f..3e6e4346e 100644 --- a/examples/benchmarks/TFT/libs/tft_model.py +++ b/examples/benchmarks/TFT/libs/tft_model.py @@ -721,12 +721,7 @@ class TemporalFusionTransformer(object): encoder_steps = self.num_encoder_steps # Inputs. - all_inputs = tf.keras.layers.Input( - shape=( - time_steps, - combined_input_size, - ) - ) + all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,)) unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs) @@ -866,10 +861,7 @@ class TemporalFusionTransformer(object): """Returns LSTM cell initialized with default parameters.""" if self.use_cudnn: lstm = tf.keras.layers.CuDNNLSTM( - self.hidden_layer_size, - return_sequences=True, - return_state=return_state, - stateful=False, + self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False, ) else: lstm = tf.keras.layers.LSTM( diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py index 3387a5947..388ec7f14 100644 --- a/examples/benchmarks/TFT/tft.py +++ b/examples/benchmarks/TFT/tft.py @@ -82,7 +82,7 @@ def process_predicted(df, col_name): """ df_res = df.copy() - df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+5": col_name}) + df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name}) df_res = df_res.set_index(["datetime", "instrument"]).sort_index() df_res = df_res[[col_name]] return df_res @@ -232,7 +232,9 @@ class TFTModel(ModelFT): p90_forecast = self.data_formatter.format_predictions(output_map["p90"]) tf.keras.backend.set_session(default_keras_session) - predict = format_score(p90_forecast, "pred", 0) # self.label_shift + predict50 = format_score(p50_forecast, "pred", 1) + predict90 = format_score(p90_forecast, "pred", 1) + predict = (predict50 + predict90) / 2 # self.label_shift # ===========================Predicting Process=========================== return predict