1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

black format

This commit is contained in:
Dong Zhou
2020-11-29 17:17:03 +08:00
parent 33f50b3cee
commit 0fb0109f9c

View File

@@ -721,7 +721,12 @@ class TemporalFusionTransformer(object):
encoder_steps = self.num_encoder_steps
# Inputs.
all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,))
all_inputs = tf.keras.layers.Input(
shape=(
time_steps,
combined_input_size,
)
)
unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)
@@ -861,7 +866,10 @@ class TemporalFusionTransformer(object):
"""Returns LSTM cell initialized with default parameters."""
if self.use_cudnn:
lstm = tf.keras.layers.CuDNNLSTM(
self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False,
self.hidden_layer_size,
return_sequences=True,
return_state=return_state,
stateful=False,
)
else:
lstm = tf.keras.layers.LSTM(