1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Update tft.py

This commit is contained in:
Wendi Li
2020-11-25 17:31:15 +08:00
committed by you-n-g
parent 9b7251d8d4
commit 068ad4ba90

View File

@@ -82,7 +82,7 @@ def process_predicted(df, col_name):
"""
df_res = df.copy()
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+0": col_name})
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+5": col_name})
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
df_res = df_res[[col_name]]
return df_res
@@ -232,8 +232,8 @@ class TFTModel(ModelFT):
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
tf.keras.backend.set_session(default_keras_session)
predict = format_score(p90_forecast, "pred", self.label_shift)
label = format_score(targets, "label", self.label_shift)
predict = format_score(p90_forecast, "pred", 0) # self.label_shift
label = format_score(targets, "label", 0)
# ===========================Predicting Process===========================
return predict, label