From e333b786507a53b3ad99023b412c5eff77bbde04 Mon Sep 17 00:00:00 2001 From: Wendi Li Date: Sat, 28 Nov 2020 22:51:22 +0800 Subject: [PATCH] Update tft.py --- examples/benchmarks/TFT/tft.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py index 3387a5947..b5398e7f2 100644 --- a/examples/benchmarks/TFT/tft.py +++ b/examples/benchmarks/TFT/tft.py @@ -82,7 +82,7 @@ def process_predicted(df, col_name): """ df_res = df.copy() - df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+5": col_name}) + df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name}) df_res = df_res.set_index(["datetime", "instrument"]).sort_index() df_res = df_res[[col_name]] return df_res @@ -232,7 +232,9 @@ class TFTModel(ModelFT): p90_forecast = self.data_formatter.format_predictions(output_map["p90"]) tf.keras.backend.set_session(default_keras_session) - predict = format_score(p90_forecast, "pred", 0) # self.label_shift + predict50 = format_score(p50_forecast, "pred", 1) + predict90 = format_score(p90_forecast, "pred", 1) + predict = (predict50 + predict90)/2 # self.label_shift # ===========================Predicting Process=========================== return predict