diff --git a/examples/benchmarks/TFT/libs/tft_model.py b/examples/benchmarks/TFT/libs/tft_model.py index 3e6e4346e..658bae60f 100644 --- a/examples/benchmarks/TFT/libs/tft_model.py +++ b/examples/benchmarks/TFT/libs/tft_model.py @@ -721,7 +721,12 @@ class TemporalFusionTransformer(object): encoder_steps = self.num_encoder_steps # Inputs. - all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,)) + all_inputs = tf.keras.layers.Input( + shape=( + time_steps, + combined_input_size, + ) + ) unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs) @@ -861,7 +866,10 @@ class TemporalFusionTransformer(object): """Returns LSTM cell initialized with default parameters.""" if self.use_cudnn: lstm = tf.keras.layers.CuDNNLSTM( - self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False, + self.hidden_layer_size, + return_sequences=True, + return_state=return_state, + stateful=False, ) else: lstm = tf.keras.layers.LSTM(