mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Update tft.py
This commit is contained in:
@@ -82,7 +82,7 @@ def process_predicted(df, col_name):
|
||||
|
||||
"""
|
||||
df_res = df.copy()
|
||||
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+5": col_name})
|
||||
df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name})
|
||||
df_res = df_res.set_index(["datetime", "instrument"]).sort_index()
|
||||
df_res = df_res[[col_name]]
|
||||
return df_res
|
||||
@@ -232,7 +232,9 @@ class TFTModel(ModelFT):
|
||||
p90_forecast = self.data_formatter.format_predictions(output_map["p90"])
|
||||
tf.keras.backend.set_session(default_keras_session)
|
||||
|
||||
predict = format_score(p90_forecast, "pred", 0) # self.label_shift
|
||||
predict50 = format_score(p50_forecast, "pred", 1)
|
||||
predict90 = format_score(p90_forecast, "pred", 1)
|
||||
predict = (predict50 + predict90)/2 # self.label_shift
|
||||
# ===========================Predicting Process===========================
|
||||
return predict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user