diff --git a/qlib/__init__.py b/qlib/__init__.py index 816e5a585..4fd48f8c1 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -154,7 +154,7 @@ def init_from_yaml_conf(conf_path, **kwargs): init(default_conf, **config) -def get_project_path(config_name="config.yaml") -> Path: +def get_project_path(config_name="config.yaml", cur_path=None) -> Path: """ If users are building a project follow the following pattern. - Qlib is a sub folder in project path @@ -181,7 +181,8 @@ def get_project_path(config_name="config.yaml") -> Path: FileNotFoundError: If project path is not found """ - cur_path = Path(__file__).absolute().resolve() + if cur_path is None: + cur_path = Path(__file__).absolute().resolve() while True: if (cur_path / config_name).exists(): return cur_path @@ -199,7 +200,7 @@ def auto_init(**kwargs): """ try: - pp = get_project_path() + pp = get_project_path(cur_path=kwargs.pop("cur_path", None)) except FileNotFoundError: init(**kwargs) else: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 2889c4465..25d02fdf6 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -50,6 +50,9 @@ class DataHandler(Serializable): SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042 SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289 + + Tips for improving the performance of datahandler + - Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc` """ def __init__( @@ -257,6 +260,10 @@ class DataHandler(Serializable): class DataHandlerLP(DataHandler): """ DataHandler with **(L)earnable (P)rocessor** + + Tips to improving the performance of data handler + - To reduce the memory cost + - `drop_raw=True`: this will modify the data inplace on raw data; """ # data key diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index c18145073..e6a0c1592 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -36,7 +36,7 @@ def task_train(task_config: dict, experiment_name: str) -> str: model.fit(dataset) recorder = R.get_recorder() R.save_objects(**{"params.pkl": model}) - R.save_objects(**{"task": task_config}) # keep the original format and datatype + R.save_objects(task=task_config) # keep the original format and datatype artifact_uri = recorder.get_artifact_uri()[7:] # delete "file://" dataset.to_pickle(artifact_uri + "/dataset", exclude=["handler"]) diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index fb7ff0b0b..86624a439 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,5 +1,6 @@ from qlib.workflow import R import pandas as pd +from tqdm.auto import tqdm from typing import Union from typing import Callable @@ -113,4 +114,4 @@ class RollingCollector(TaskCollector): pred = pred.sort_index() reduce_group[k] = pred - return reduce_group \ No newline at end of file + return reduce_group diff --git a/qlib/workflow/task/update.py b/qlib/workflow/task/update.py index fcee84349..441522018 100644 --- a/qlib/workflow/task/update.py +++ b/qlib/workflow/task/update.py @@ -96,8 +96,8 @@ class ModelUpdater: .. code-block:: python - def rec_filter_func(recorder): - task_config = recorder.load_object("task") + def record_filter(record): + task_config = record.load_object("task") if task_config["model"]["class"]=="LGBModel": return True return False