mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
replace multi processing with joblib (#477)
* replace multi processing with joblib * update class Parallel and data.py * update class Parallel and data.py * update class Parallel and data.py * update class Parallel and data.py * update class Parallel and data.py * update class Parallel and data.py * update class Parallel and data.py * update class Parallel and data.py * Fix Parallel support for maxtasksperchild Co-authored-by: wangw <1666490690@qq.com> Co-authored-by: zhupr <zhu.pengrong@foxmail.com>
This commit is contained in:
@@ -92,6 +92,8 @@ _default_config = {
|
||||
"kernels": NUM_USABLE_CPU,
|
||||
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
|
||||
"maxtasksperchild": None,
|
||||
# If joblib_backend is None, use loky
|
||||
"joblib_backend": "multiprocessing",
|
||||
"default_disk_cache": 1, # 0:skip/1:use
|
||||
"mem_cache_size_limit": 500,
|
||||
# memory cache expiry in seconds, only used in 'DatasetURICache' and 'client D.calendar'
|
||||
|
||||
@@ -9,16 +9,15 @@ import os
|
||||
import re
|
||||
import abc
|
||||
import copy
|
||||
import time
|
||||
import queue
|
||||
import bisect
|
||||
import logging
|
||||
import importlib
|
||||
import traceback
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
|
||||
# For supporting multiprocessing in outer code, joblib is used
|
||||
from joblib import delayed
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
@@ -29,6 +28,7 @@ from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils.resam import resam_calendar
|
||||
from ..utils.paral import ParallelExt
|
||||
|
||||
|
||||
class ProviderBackendMixin:
|
||||
@@ -418,16 +418,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")
|
||||
|
||||
def _uri(
|
||||
self,
|
||||
instruments,
|
||||
fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=1,
|
||||
**kwargs,
|
||||
):
|
||||
def _uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, **kwargs):
|
||||
"""Get task uri, used when generating rabbitmq task in qlib_server
|
||||
|
||||
Parameters
|
||||
@@ -494,51 +485,37 @@ class DatasetProvider(abc.ABC):
|
||||
|
||||
"""
|
||||
normalize_column_names = normalize_cache_fields(column_names)
|
||||
data = dict()
|
||||
# One process for one task, so that the memory will be freed quicker.
|
||||
workers = max(min(C.kernels, len(instruments_d)), 1)
|
||||
|
||||
if C.maxtasksperchild is None:
|
||||
p = Pool(processes=workers)
|
||||
else:
|
||||
p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)
|
||||
# create iterator
|
||||
if isinstance(instruments_d, dict):
|
||||
for inst, spans in instruments_d.items():
|
||||
data[inst] = p.apply_async(
|
||||
DatasetProvider.expression_calculator,
|
||||
args=(
|
||||
inst,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
normalize_column_names,
|
||||
spans,
|
||||
C,
|
||||
),
|
||||
)
|
||||
it = instruments_d.items()
|
||||
else:
|
||||
for inst in instruments_d:
|
||||
data[inst] = p.apply_async(
|
||||
DatasetProvider.expression_calculator,
|
||||
args=(
|
||||
inst,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
normalize_column_names,
|
||||
None,
|
||||
C,
|
||||
),
|
||||
)
|
||||
it = zip(instruments_d, [None] * len(instruments_d))
|
||||
|
||||
p.close()
|
||||
p.join()
|
||||
inst_l = []
|
||||
task_l = []
|
||||
for inst, spans in it:
|
||||
inst_l.append(inst)
|
||||
task_l.append(
|
||||
delayed(DatasetProvider.expression_calculator)(
|
||||
inst, start_time, end_time, freq, normalize_column_names, spans, C
|
||||
)
|
||||
)
|
||||
|
||||
data = dict(
|
||||
zip(
|
||||
inst_l,
|
||||
ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(task_l),
|
||||
)
|
||||
)
|
||||
|
||||
new_data = dict()
|
||||
for inst in sorted(data.keys()):
|
||||
if len(data[inst].get()) > 0:
|
||||
if len(data[inst]) > 0:
|
||||
# NOTE: Python version >= 3.6; in versions after python3.6, dict will always guarantee the insertion order
|
||||
new_data[inst] = data[inst].get()
|
||||
new_data[inst] = data[inst]
|
||||
|
||||
if len(new_data) > 0:
|
||||
data = pd.concat(new_data, names=["instrument"], sort=False)
|
||||
@@ -755,25 +732,11 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
workers = max(min(C.kernels, len(instruments_d)), 1)
|
||||
if C.maxtasksperchild is None:
|
||||
p = Pool(processes=workers)
|
||||
else:
|
||||
p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)
|
||||
|
||||
for inst in instruments_d:
|
||||
p.apply_async(
|
||||
LocalDatasetProvider.cache_walker,
|
||||
args=(
|
||||
inst,
|
||||
start_time,
|
||||
end_time,
|
||||
freq,
|
||||
column_names,
|
||||
),
|
||||
)
|
||||
|
||||
p.close()
|
||||
p.join()
|
||||
ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(
|
||||
delayed(LocalDatasetProvider.cache_walker)(inst, start_time, end_time, freq, column_names)
|
||||
for inst in instruments_d
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cache_walker(inst, start_time, end_time, freq, column_names):
|
||||
@@ -803,12 +766,7 @@ class ClientCalendarProvider(CalendarProvider):
|
||||
|
||||
self.conn.send_request(
|
||||
request_type="calendar",
|
||||
request_content={
|
||||
"start_time": str(start_time),
|
||||
"end_time": str(end_time),
|
||||
"freq": freq,
|
||||
"future": future,
|
||||
},
|
||||
request_content={"start_time": str(start_time), "end_time": str(end_time), "freq": freq, "future": future},
|
||||
msg_queue=self.queue,
|
||||
msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content],
|
||||
)
|
||||
@@ -871,16 +829,7 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
self.conn = conn
|
||||
self.queue = queue.Queue()
|
||||
|
||||
def dataset(
|
||||
self,
|
||||
instruments,
|
||||
fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=0,
|
||||
return_uri=False,
|
||||
):
|
||||
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, return_uri=False):
|
||||
if Inst.get_inst_type(instruments) == Inst.DICT:
|
||||
get_module_logger("data").warning(
|
||||
"Getting features from a dict of instruments is not recommended because the features will not be "
|
||||
@@ -984,15 +933,7 @@ class BaseProvider:
|
||||
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
|
||||
return Inst.list_instruments(instruments, start_time, end_time, freq, as_list)
|
||||
|
||||
def features(
|
||||
self,
|
||||
instruments,
|
||||
fields,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
freq="day",
|
||||
disk_cache=None,
|
||||
):
|
||||
def features(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=None):
|
||||
"""
|
||||
Parameters:
|
||||
-----------
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
import pandas as pd
|
||||
from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
|
||||
|
||||
class ParallelExt(Parallel):
    """``joblib.Parallel`` that additionally accepts a ``maxtasksperchild`` keyword.

    The extra keyword is removed from ``kwargs`` before they are handed to
    ``Parallel.__init__``. When the backend selected by joblib is a
    ``MultiprocessingBackend``, the value is stored in the backend option
    dict kept on the instance; for any other backend it is ignored.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``joblib.Parallel``.
    maxtasksperchild : int or None, keyword-only, default None
        Stored under the ``"maxtasksperchild"`` key of the backend options.
    """

    # Private joblib attributes that may hold the backend option dict.
    # NOTE(review): older joblib exposes `_backend_args`; newer releases
    # (1.5.0 onwards, to the best of my knowledge) use `_backend_kwargs`
    # instead - confirm against the joblib versions this project supports.
    _BACKEND_OPTION_ATTRS = ("_backend_kwargs", "_backend_args")

    def __init__(self, *args, **kwargs):
        # Pop our own keyword first so it is not forwarded to Parallel.__init__.
        maxtasksperchild = kwargs.pop("maxtasksperchild", None)
        super().__init__(*args, **kwargs)

        if not isinstance(self._backend, MultiprocessingBackend):
            return

        for attr_name in self._BACKEND_OPTION_ATTRS:
            backend_options = getattr(self, attr_name, None)
            if isinstance(backend_options, dict):
                backend_options["maxtasksperchild"] = maxtasksperchild
                return

        # Neither attribute exists: fail loudly instead of silently dropping the option.
        raise AttributeError(
            "ParallelExt could not find joblib's backend option dict "
            "(tried: {}); the installed joblib version is not supported".format(
                ", ".join(self._BACKEND_OPTION_ATTRS)
            )
        )
|
||||
|
||||
|
||||
def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False):
|
||||
@@ -31,7 +40,7 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
|
||||
return df.groupby(axis=axis, level=level).apply(apply_func)
|
||||
|
||||
if n_jobs != 1:
|
||||
dfs = Parallel(n_jobs=n_jobs)(
|
||||
dfs = ParallelExt(n_jobs=n_jobs)(
|
||||
delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, axis=axis, level=level)
|
||||
)
|
||||
return pd.concat(dfs, axis=axis).sort_index()
|
||||
|
||||
39
tests/misc/test_get_multi_proc.py
Normal file
39
tests/misc/test_get_multi_proc.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.tests import TestAutoData
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
def get_features(fields):
    """Initialise qlib inside the calling process and load ``fields`` for csi300.

    qlib is initialised with the test data provider, both caches disabled and
    the ``loky`` joblib backend; the result of ``D.features`` is returned.
    """
    qlib.init(
        provider_uri=TestAutoData.provider_uri,
        expression_cache=None,
        dataset_cache=None,
        joblib_backend="loky",
    )
    universe = D.instruments("csi300")
    return D.features(universe, fields)
|
||||
|
||||
|
||||
class TestGetData(TestAutoData):
    """Checks that ``D.features`` can be called from inside ``multiprocessing`` workers."""

    FIELDS = "$open,$close,$high,$low,$volume,$factor,$change".split(",")

    def test_multi_proc(self):
        """
        For testing if it will raise error

        ``get_features`` is submitted ``iter_n`` times to a process pool and
        every result is fetched, so an exception raised in a worker is
        re-raised here by ``AsyncResult.get`` and fails the test.
        """
        iter_n = 2
        pool = Pool(iter_n)
        try:
            res = [pool.apply_async(get_features, (self.FIELDS,), {}) for _ in range(iter_n)]

            for r in res:
                print(r.get())
        finally:
            # Always release the worker processes, even when a task failed.
            pool.close()
            pool.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user