1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Update TFT

This commit is contained in:
v-blin
2020-11-28 14:55:16 +00:00
parent 30ab4a8d8b
commit fdf0f9a182
2 changed files with 3 additions and 11 deletions

View File

@@ -721,12 +721,7 @@ class TemporalFusionTransformer(object):
encoder_steps = self.num_encoder_steps
# Inputs.
all_inputs = tf.keras.layers.Input(
shape=(
time_steps,
combined_input_size,
)
)
all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,))
unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)
@@ -866,10 +861,7 @@ class TemporalFusionTransformer(object):
"""Returns LSTM cell initialized with default parameters."""
if self.use_cudnn:
lstm = tf.keras.layers.CuDNNLSTM(
self.hidden_layer_size,
return_sequences=True,
return_state=return_state,
stateful=False,
self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False,
)
else:
lstm = tf.keras.layers.LSTM(

View File

@@ -234,7 +234,7 @@ class TFTModel(ModelFT):
predict50 = format_score(p50_forecast, "pred", 1)
predict90 = format_score(p90_forecast, "pred", 1)
predict = (predict50 + predict90)/2 # self.label_shift
predict = (predict50 + predict90) / 2 # self.label_shift
# ===========================Predicting Process===========================
return predict