diff --git a/examples/benchmarks/TFT/libs/tft_model.py b/examples/benchmarks/TFT/libs/tft_model.py index 658bae60f..3e6e4346e 100644 --- a/examples/benchmarks/TFT/libs/tft_model.py +++ b/examples/benchmarks/TFT/libs/tft_model.py @@ -721,12 +721,7 @@ class TemporalFusionTransformer(object): encoder_steps = self.num_encoder_steps # Inputs. - all_inputs = tf.keras.layers.Input( - shape=( - time_steps, - combined_input_size, - ) - ) + all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,)) unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs) @@ -866,10 +861,7 @@ class TemporalFusionTransformer(object): """Returns LSTM cell initialized with default parameters.""" if self.use_cudnn: lstm = tf.keras.layers.CuDNNLSTM( - self.hidden_layer_size, - return_sequences=True, - return_state=return_state, - stateful=False, + self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False, ) else: lstm = tf.keras.layers.LSTM( diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py index b5398e7f2..388ec7f14 100644 --- a/examples/benchmarks/TFT/tft.py +++ b/examples/benchmarks/TFT/tft.py @@ -234,7 +234,7 @@ class TFTModel(ModelFT): predict50 = format_score(p50_forecast, "pred", 1) predict90 = format_score(p90_forecast, "pred", 1) - predict = (predict50 + predict90)/2 # self.label_shift + predict = (predict50 + predict90) / 2 # self.label_shift # ===========================Predicting Process=========================== return predict