diff --git a/qlib/config.py b/qlib/config.py index bee181133..8c54f6af9 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -75,6 +75,17 @@ class Config: def set_conf_from_C(self, config_c): self.update(**config_c.__dict__["_config"]) + def register_from_C(self, config, skip_register=True): + from .utils import set_log_with_config # pylint: disable=C0415 + + if C.registered and skip_register: + return + + C.set_conf_from_C(config) + if C.logging_config: + set_log_with_config(C.logging_config) + C.register() + # pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format PROTOCOL_VERSION = 4 diff --git a/qlib/data/data.py b/qlib/data/data.py index 7e602a17e..08320cae5 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -32,7 +32,6 @@ from ..utils import ( hash_args, normalize_cache_fields, code_to_fname, - set_log_with_config, time_to_slc_point, read_period_data, get_period_list, @@ -603,11 +602,7 @@ class DatasetProvider(abc.ABC): """ # FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods # NOTE: This place is compatible with windows, windows multi-process is spawn - if not C.registered: - C.set_conf_from_C(g_config) - if C.logging_config: - set_log_with_config(C.logging_config) - C.register() + C.register_from_C(g_config) obj = dict() for field in column_names: diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 6de94b755..4bfa6337b 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -15,13 +15,22 @@ import socket from typing import Callable, List from tqdm.auto import tqdm + +from qlib.config import C from qlib.data.dataset import Dataset +from qlib.data.dataset.weight import Reweighter +from qlib.log import get_module_logger from qlib.model.base import Model -from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder +from qlib.utils import ( + auto_filter_kwargs, + fill_placeholder, + flatten_dict, + init_instance_by_config, +) +from qlib.utils.paral import call_in_subproc from qlib.workflow import R from qlib.workflow.recorder import Recorder from qlib.workflow.task.manage import TaskManager, run_task -from qlib.data.dataset.weight import Reweighter def _log_task_info(task_config: dict): @@ -210,17 +219,19 @@ class TrainerR(Trainer): STATUS_BEGIN = "begin_task_train" STATUS_END = "end_task_train" - def __init__(self, experiment_name: str = None, train_func: Callable = task_train): + def __init__(self, experiment_name: str = None, train_func: Callable = task_train, call_in_subproc: bool = False): """ Init TrainerR. Args: experiment_name (str, optional): the default name of experiment. train_func (Callable, optional): default training method. Defaults to `task_train`. + call_in_subproc (bool): call the process in subprocess to force memory release """ super().__init__() self.experiment_name = experiment_name self.train_func = train_func + self._call_in_subproc = call_in_subproc def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]: """ @@ -245,6 +256,9 @@ class TrainerR(Trainer): experiment_name = self.experiment_name recs = [] for task in tqdm(tasks, desc="train tasks"): + if self._call_in_subproc: + get_module_logger("TrainerR").info("running models in sub process (for forcing release memroy).") + train_func = call_in_subproc(train_func, C) rec = train_func(task, experiment_name, **kwargs) rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) recs.append(rec) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index c095acbb3..6334280ce 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -949,6 +949,10 @@ def auto_filter_kwargs(func: Callable, warning=True) -> Callable: The decrated function will ignore and give warning when the parameter is not acceptable + For example, if you have a function `f` which may optionally consume the keywards `bar`. + then you can call it by `auto_filter_kwargs(f)(bar=3)`, which will automatically filter out + `bar` when f does not need bar + Parameters ---------- func : Callable diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py index df9e530c0..439ca34b0 100644 --- a/qlib/utils/paral.py +++ b/qlib/utils/paral.py @@ -10,6 +10,9 @@ from joblib._parallel_backends import MultiprocessingBackend import pandas as pd from queue import Queue +import concurrent + +from qlib.config import C, QlibConfig class ParallelExt(Parallel): @@ -273,3 +276,40 @@ def complex_parallel(paral: Parallel, complex_iter): dt.set_res(res) complex_iter = _recover_dt(complex_iter) return complex_iter + + +class call_in_subproc: + """ + When we repeating run functions, it is hard to avoid memory leakage. + So we run it in the subprocess to ensure it is OK. + + NOTE: Because local object can't be pickled. So we can't implement it via closure. + We have to implement it via callable Class + """ + + def __init__(self, func: Callable, qlib_config: QlibConfig = None): + """ + Parameters + ---------- + func : Callable + the function to be wrapped + + qlib_config : QlibConfig + Qlib config for initialization in subprocess + + Returns + ------- + Callable + """ + self.func = func + self.qlib_config = qlib_config + + def _func_mod(self, *args, **kwargs): + """Modify the initial function by adding Qlib initialization""" + if self.qlib_config is not None: + C.register_from_C(self.qlib_config) + return self.func(*args, **kwargs) + + def __call__(self, *args, **kwargs): + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + return executor.submit(self._func_mod, *args, **kwargs).result()