1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

Merge pull request #9 from wendili-cs/main

Update TFT
This commit is contained in:
you-n-g
2020-11-28 23:01:45 +08:00
committed by GitHub
3 changed files with 14 additions and 20 deletions

View File

@@ -194,10 +194,10 @@ class Alpha158Formatter(GenericDataFormatter):
"""Returns fixed model parameters for experiments."""
fixed_params = {
"total_time_steps": 16 + 6,
"num_encoder_steps": 16,
"total_time_steps": 6 + 6,
"num_encoder_steps": 6,
"num_epochs": 100,
"early_stopping_patience": 5,
"early_stopping_patience": 10,
"multiprocessing_workers": 5,
}
@@ -207,11 +207,11 @@ class Alpha158Formatter(GenericDataFormatter):
"""Returns default optimised model parameters."""
model_params = {
"dropout_rate": 0.3,
"hidden_layer_size": 160,
"learning_rate": 0.001,
"minibatch_size": 64,
"max_gradient_norm": 0.01,
"dropout_rate": 0.4,
"hidden_layer_size": 16,
"learning_rate": 0.0001,
"minibatch_size": 128,
"max_gradient_norm": 0.0135,
"num_heads": 1,
"stack_size": 1,
}

View File

@@ -721,12 +721,7 @@ class TemporalFusionTransformer(object):
encoder_steps = self.num_encoder_steps
# Inputs.
all_inputs = tf.keras.layers.Input(
shape=(
time_steps,
combined_input_size,
)
)
all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,))
unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)
@@ -866,10 +861,7 @@ class TemporalFusionTransformer(object):
"""Returns LSTM cell initialized with default parameters."""
if self.use_cudnn:
lstm = tf.keras.layers.CuDNNLSTM(
self.hidden_layer_size,
return_sequences=True,
return_state=return_state,
stateful=False,
self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False,
)
else:
lstm = tf.keras.layers.LSTM(

View File

@@ -82,7 +82,7 @@ def process_predicted(df, col_name):
"""
df_res = df.copy()
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+5": col_name})
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
df_res = df_res[[col_name]]
return df_res
@@ -232,7 +232,9 @@ class TFTModel(ModelFT):
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
tf.keras.backend.set_session(default_keras_session)
predict = format_score(p90_forecast, "pred", 0) # self.label_shift
predict50 = format_score(p50_forecast, "pred", 1)
predict90 = format_score(p90_forecast, "pred", 1)
predict = (predict50 + predict90) / 2 # self.label_shift
# ===========================Predicting Process===========================
return predict