diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py index 92d957d0f..a3b4fc919 100644 --- a/examples/benchmarks/TFT/tft.py +++ b/examples/benchmarks/TFT/tft.py @@ -232,7 +232,7 @@ class TFTModel(ModelFT): p90_forecast = self.data_formatter.format_predictions(output_map["p90"]) tf.keras.backend.set_session(default_keras_session) - predict = format_score(p90_forecast, "pred", 0) # self.label_shift + predict = format_score(p90_forecast, "pred", 0) # self.label_shift label = format_score(targets, "label", 0) # ===========================Predicting Process=========================== return predict, label diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 404313e80..db6b1440d 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -21,18 +21,16 @@ class DataLoader(abc.ABC): @abc.abstractmethod def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: """ - load the data as pd.DataFrame + load the data as pd.DataFrame. Parameters ---------- - self : [TODO:type] - [TODO:description] - instruments : [TODO:type] - [TODO:description] - start_time : [TODO:type] - [TODO:description] - end_time : [TODO:type] - [TODO:description] + instruments : str or dict + it can either be the market name or the config file of instruments generated by InstrumentProvider. + start_time : str + start of the time range. + end_time : str + end of the time range. Returns -------