diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml index 574b52ddd..720599be5 100644 --- a/examples/benchmarks/CatBoost/workflow_config_catboost.yaml +++ b/examples/benchmarks/CatBoost/workflow_config_catboost.yaml @@ -32,7 +32,11 @@ task: loss: RMSE learning_rate: 0.0421 subsample: 0.8789 + max_depth: 6 + num_leaves: 100 thread_count: 20 + grow_policy: Lossguide + boostrap_type: Poisson dataset: class: DatasetH module_path: qlib.data.dataset diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py index 92d957d0f..a3b4fc919 100644 --- a/examples/benchmarks/TFT/tft.py +++ b/examples/benchmarks/TFT/tft.py @@ -232,7 +232,7 @@ class TFTModel(ModelFT): p90_forecast = self.data_formatter.format_predictions(output_map["p90"]) tf.keras.backend.set_session(default_keras_session) - predict = format_score(p90_forecast, "pred", 0) # self.label_shift + predict = format_score(p90_forecast, "pred", 0) # self.label_shift label = format_score(targets, "label", 0) # ===========================Predicting Process=========================== return predict, label diff --git a/examples/run_all_model.py b/examples/run_all_model.py index f8894afd3..b448a1857 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -178,8 +178,7 @@ def get_all_folders() -> dict: folders = dict() for f in os.scandir("benchmarks"): path = Path("benchmarks") / f.name - if f.name != "TFT": - folders[f.name] = str(path.resolve()) + folders[f.name] = str(path.resolve()) return folders diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 404313e80..db6b1440d 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -21,18 +21,16 @@ class DataLoader(abc.ABC): @abc.abstractmethod def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame: """ - load the data as pd.DataFrame + load the data as pd.DataFrame. Parameters ---------- - self : [TODO:type] - [TODO:description] - instruments : [TODO:type] - [TODO:description] - start_time : [TODO:type] - [TODO:description] - end_time : [TODO:type] - [TODO:description] + instruments : str or dict + it can either be the market name or the config file of instruments generated by InstrumentProvider. + start_time : str + start of the time range. + end_time : str + end of the time range. Returns -------