1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 11:00:57 +08:00

opt local trainer (better mem releasing) (#1116)

* opt local trainer (better mem releasing)

* Update setup.py

* Update data.py

* fix CI
This commit is contained in:
you-n-g
2022-06-14 11:58:39 +08:00
committed by GitHub
parent e24ef67663
commit afcea404a5
5 changed files with 73 additions and 9 deletions

View File

@@ -75,6 +75,17 @@ class Config:
def set_conf_from_C(self, config_c):
    """
    Copy the settings held by another config object into this one.

    Parameters
    ----------
    config_c :
        the source config object; the mapping stored under the `_config` key of
        its instance `__dict__` is expanded into keyword arguments of `self.update`
    """
    source_settings = config_c.__dict__["_config"]
    self.update(**source_settings)
def register_from_C(self, config, skip_register=True):
    """
    Load the settings of `config` into the global `C` and register it.

    Parameters
    ----------
    config :
        the config object whose settings are copied into `C` (via `C.set_conf_from_C`)
    skip_register : bool
        when True, nothing is done if `C` has already been registered
    """
    # NOTE(review): imported locally rather than at module level, presumably to
    # avoid a circular import between the config and utils modules -- confirm.
    from .utils import set_log_with_config  # pylint: disable=C0415

    if not (C.registered and skip_register):
        C.set_conf_from_C(config)
        # Only (re)configure logging when a logging config is actually provided.
        if C.logging_config:
            set_log_with_config(C.logging_config)
        C.register()
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
PROTOCOL_VERSION = 4

View File

@@ -32,7 +32,6 @@ from ..utils import (
hash_args,
normalize_cache_fields,
code_to_fname,
set_log_with_config,
time_to_slc_point,
read_period_data,
get_period_list,
@@ -603,11 +602,7 @@ class DatasetProvider(abc.ABC):
"""
# FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods
# NOTE: This place is compatible with windows, windows multi-process is spawn
if not C.registered:
C.set_conf_from_C(g_config)
if C.logging_config:
set_log_with_config(C.logging_config)
C.register()
C.register_from_C(g_config)
obj = dict()
for field in column_names:

View File

@@ -15,13 +15,22 @@ import socket
from typing import Callable, List
from tqdm.auto import tqdm
from qlib.config import C
from qlib.data.dataset import Dataset
from qlib.data.dataset.weight import Reweighter
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder
from qlib.utils import (
auto_filter_kwargs,
fill_placeholder,
flatten_dict,
init_instance_by_config,
)
from qlib.utils.paral import call_in_subproc
from qlib.workflow import R
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.data.dataset.weight import Reweighter
def _log_task_info(task_config: dict):
@@ -210,17 +219,19 @@ class TrainerR(Trainer):
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
def __init__(self, experiment_name: str = None, train_func: Callable = task_train, call_in_subproc: bool = False):
    """
    Set up a TrainerR with its default experiment name and training function.

    Args:
        experiment_name (str, optional): name of the experiment used by default.
        train_func (Callable, optional): training function used by default. Defaults to `task_train`.
        call_in_subproc (bool): when True, training is run in a subprocess to force memory release.
    """
    super().__init__()
    # The flag is stored under a private name; the parameter itself shadows the
    # imported `call_in_subproc` helper only inside this constructor.
    self._call_in_subproc = call_in_subproc
    self.experiment_name, self.train_func = experiment_name, train_func
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
@@ -245,6 +256,9 @@ class TrainerR(Trainer):
experiment_name = self.experiment_name
recs = []
# Wrap the training function exactly once, before iterating over the tasks.
# Re-assigning `train_func = call_in_subproc(train_func, C)` inside the loop
# wraps the already-wrapped callable again on every iteration, so the k-th
# task would run through k nested subprocesses instead of a single one.
if self._call_in_subproc:
    run_func = call_in_subproc(train_func, C)
else:
    run_func = train_func
for task in tqdm(tasks, desc="train tasks"):
    if self._call_in_subproc:
        get_module_logger("TrainerR").info("running models in sub process (for forcing release memroy).")
    rec = run_func(task, experiment_name, **kwargs)
    # Tag the recorder so that the training status of the task can be tracked.
    rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
    recs.append(rec)

View File

@@ -949,6 +949,10 @@ def auto_filter_kwargs(func: Callable, warning=True) -> Callable:
The decorated function will ignore and give warning when the parameter is not acceptable
For example, if you have a function `f` which may optionally consume the keyword `bar`,
then you can call it by `auto_filter_kwargs(f)(bar=3)`, which will automatically filter out
`bar` when f does not need bar
Parameters
----------
func : Callable

View File

@@ -10,6 +10,9 @@ from joblib._parallel_backends import MultiprocessingBackend
import pandas as pd
from queue import Queue
import concurrent
from qlib.config import C, QlibConfig
class ParallelExt(Parallel):
@@ -273,3 +276,40 @@ def complex_parallel(paral: Parallel, complex_iter):
dt.set_res(res)
complex_iter = _recover_dt(complex_iter)
return complex_iter
class call_in_subproc:
    """
    Callable wrapper that runs a function in a separate, short-lived subprocess.

    When functions are run repeatedly, it is hard to avoid memory leakage.
    So each call is executed in its own worker process, which is shut down as
    soon as the call has returned.

    NOTE: Because local objects can't be pickled, this can't be implemented via a closure.
    It has to be implemented as a callable class.
    """

    def __init__(self, func: Callable, qlib_config: "QlibConfig" = None):
        """
        Parameters
        ----------
        func : Callable
            the function to be wrapped
        qlib_config : QlibConfig
            Qlib config for initialization in subprocess;
            when it is None, no Qlib initialization is done in the subprocess
        """
        self.func = func
        self.qlib_config = qlib_config

    def _func_mod(self, *args, **kwargs):
        """Modify the initial function by adding Qlib initialization"""
        if self.qlib_config is not None:
            C.register_from_C(self.qlib_config)
        return self.func(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """
        Run the wrapped function with the given arguments in a new subprocess.

        The call blocks until the subprocess has finished and returns the result of
        the wrapped function; an exception raised by it is re-raised by `.result()`.
        """
        # Import the submodule explicitly: a bare `import concurrent` does not
        # load `concurrent.futures`, so `concurrent.futures.ProcessPoolExecutor`
        # only resolves if another module happened to import it first.
        from concurrent.futures import ProcessPoolExecutor  # pylint: disable=C0415

        # A dedicated single-worker pool per call; leaving the `with` block shuts it down.
        with ProcessPoolExecutor(max_workers=1) as executor:
            return executor.submit(self._func_mod, *args, **kwargs).result()