diff --git a/qlib/__init__.py b/qlib/__init__.py index 816e5a585..4fd48f8c1 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -154,7 +154,7 @@ def init_from_yaml_conf(conf_path, **kwargs): init(default_conf, **config) -def get_project_path(config_name="config.yaml") -> Path: +def get_project_path(config_name="config.yaml", cur_path=None) -> Path: """ If users are building a project follow the following pattern. - Qlib is a sub folder in project path @@ -181,7 +181,8 @@ def get_project_path(config_name="config.yaml") -> Path: FileNotFoundError: If project path is not found """ - cur_path = Path(__file__).absolute().resolve() + if cur_path is None: + cur_path = Path(__file__).absolute().resolve() while True: if (cur_path / config_name).exists(): return cur_path @@ -199,7 +200,7 @@ def auto_init(**kwargs): """ try: - pp = get_project_path() + pp = get_project_path(cur_path=kwargs.pop("cur_path", None)) except FileNotFoundError: init(**kwargs) else: diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 5e62a141c..a4df92218 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -34,7 +34,7 @@ def task_train(task_config: dict, experiment_name: str) -> str: model.fit(dataset) recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) - R.save_objects(**{"task.pkl": task_config}) # keep the original format and datatype + R.save_objects(task=task_config) # keep the original format and datatype # generate records: prediction, backtest, and analysis records = task_config.get("record", []) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index b16312ff7..b4a584494 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,5 +1,6 @@ from qlib.workflow import R import pandas as pd +import tqdm.auto import tqdm from typing import Union from qlib import get_module_logger @@ -35,7 +36,7 @@ class TaskCollector: recs_flt = {} for rid, rec in recs.items(): - params = rec.load_object("task.pkl") + params = rec.load_object("task") if rec.status == rec.STATUS_FI: if filter_func is None or filter_func(params): rec.params = params @@ -83,7 +84,7 @@ class RollingCollector: recs_flt = {} for rid, rec in tqdm(recs.items(), desc="Loading data"): - params = rec.load_object("task.pkl") + params = rec.load_object("task") if rec.status == rec.STATUS_FI: if self.flt_func is None or self.flt_func(params): rec.params = params diff --git a/qlib/workflow/task/update.py b/qlib/workflow/task/update.py index f9d03efbc..628225a20 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/task/update.py @@ -74,7 +74,7 @@ class ModelUpdater: rec = self.exp.get_recorder(recorder_id=rid) old_pred = rec.load_object("pred.pkl") last_end = old_pred.index.get_level_values("datetime").max() - task_config = rec.load_object("task.pkl") + task_config = rec.load_object("task") # updated to the latest trading day cal = D.calendar(start_time=last_end + pd.Timedelta(days=1), end_time=None) @@ -107,7 +107,7 @@ class ModelUpdater: .. code-block:: python def record_filter(record): - task_config = record.load_object("task.pkl") + task_config = record.load_object("task") if task_config["model"]["class"]=="LGBModel": return True return False