mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 11:00:57 +08:00
opt local trainer (better mem releasing) (#1116)
* opt local trainer (better mem releasing) * Update setup.py * Update data.py * fix CI
This commit is contained in:
@@ -75,6 +75,17 @@ class Config:
|
||||
def set_conf_from_C(self, config_c):
|
||||
self.update(**config_c.__dict__["_config"])
|
||||
|
||||
def register_from_C(self, config, skip_register=True):
|
||||
from .utils import set_log_with_config # pylint: disable=C0415
|
||||
|
||||
if C.registered and skip_register:
|
||||
return
|
||||
|
||||
C.set_conf_from_C(config)
|
||||
if C.logging_config:
|
||||
set_log_with_config(C.logging_config)
|
||||
C.register()
|
||||
|
||||
|
||||
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||
PROTOCOL_VERSION = 4
|
||||
|
||||
@@ -32,7 +32,6 @@ from ..utils import (
|
||||
hash_args,
|
||||
normalize_cache_fields,
|
||||
code_to_fname,
|
||||
set_log_with_config,
|
||||
time_to_slc_point,
|
||||
read_period_data,
|
||||
get_period_list,
|
||||
@@ -603,11 +602,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
# FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods
|
||||
# NOTE: This place is compatible with windows, windows multi-process is spawn
|
||||
if not C.registered:
|
||||
C.set_conf_from_C(g_config)
|
||||
if C.logging_config:
|
||||
set_log_with_config(C.logging_config)
|
||||
C.register()
|
||||
C.register_from_C(g_config)
|
||||
|
||||
obj = dict()
|
||||
for field in column_names:
|
||||
|
||||
@@ -15,13 +15,22 @@ import socket
|
||||
from typing import Callable, List
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from qlib.config import C
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder
|
||||
from qlib.utils import (
|
||||
auto_filter_kwargs,
|
||||
fill_placeholder,
|
||||
flatten_dict,
|
||||
init_instance_by_config,
|
||||
)
|
||||
from qlib.utils.paral import call_in_subproc
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
|
||||
def _log_task_info(task_config: dict):
|
||||
@@ -210,17 +219,19 @@ class TrainerR(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
|
||||
def __init__(self, experiment_name: str = None, train_func: Callable = task_train, call_in_subproc: bool = False):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str, optional): the default name of experiment.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
call_in_subproc (bool): call the process in subprocess to force memory release
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.train_func = train_func
|
||||
self._call_in_subproc = call_in_subproc
|
||||
|
||||
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
@@ -245,6 +256,9 @@ class TrainerR(Trainer):
|
||||
experiment_name = self.experiment_name
|
||||
recs = []
|
||||
for task in tqdm(tasks, desc="train tasks"):
|
||||
if self._call_in_subproc:
|
||||
get_module_logger("TrainerR").info("running models in sub process (for forcing release memroy).")
|
||||
train_func = call_in_subproc(train_func, C)
|
||||
rec = train_func(task, experiment_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
|
||||
@@ -949,6 +949,10 @@ def auto_filter_kwargs(func: Callable, warning=True) -> Callable:
|
||||
|
||||
The decrated function will ignore and give warning when the parameter is not acceptable
|
||||
|
||||
For example, if you have a function `f` which may optionally consume the keywards `bar`.
|
||||
then you can call it by `auto_filter_kwargs(f)(bar=3)`, which will automatically filter out
|
||||
`bar` when f does not need bar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
|
||||
@@ -10,6 +10,9 @@ from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
|
||||
from queue import Queue
|
||||
import concurrent
|
||||
|
||||
from qlib.config import C, QlibConfig
|
||||
|
||||
|
||||
class ParallelExt(Parallel):
|
||||
@@ -273,3 +276,40 @@ def complex_parallel(paral: Parallel, complex_iter):
|
||||
dt.set_res(res)
|
||||
complex_iter = _recover_dt(complex_iter)
|
||||
return complex_iter
|
||||
|
||||
|
||||
class call_in_subproc:
|
||||
"""
|
||||
When we repeating run functions, it is hard to avoid memory leakage.
|
||||
So we run it in the subprocess to ensure it is OK.
|
||||
|
||||
NOTE: Because local object can't be pickled. So we can't implement it via closure.
|
||||
We have to implement it via callable Class
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable, qlib_config: QlibConfig = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
the function to be wrapped
|
||||
|
||||
qlib_config : QlibConfig
|
||||
Qlib config for initialization in subprocess
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
"""
|
||||
self.func = func
|
||||
self.qlib_config = qlib_config
|
||||
|
||||
def _func_mod(self, *args, **kwargs):
|
||||
"""Modify the initial function by adding Qlib initialization"""
|
||||
if self.qlib_config is not None:
|
||||
C.register_from_C(self.qlib_config)
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
return executor.submit(self._func_mod, *args, **kwargs).result()
|
||||
|
||||
Reference in New Issue
Block a user