1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

replace multi processing with joblib (#477)

* replace multi processing with joblib

* update class Parallel and data.py

* update class Parallel and data.py

* update class Parallel and data.py

* update class Parallel and data.py

* update class Parallel and data.py

* update class Parallel and data.py

* update class Parallel and data.py

* update class Parallel and data.py

* Fix Parallel support for maxtasksperchild

Co-authored-by: wangw <1666490690@qq.com>
Co-authored-by: zhupr <zhu.pengrong@foxmail.com>
This commit is contained in:
you-n-g
2021-09-14 01:16:03 +08:00
committed by GitHub
parent 6203e4c09e
commit 163e3c6266
4 changed files with 86 additions and 95 deletions

View File

@@ -92,6 +92,8 @@ _default_config = {
"kernels": NUM_USABLE_CPU,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
# If joblib_backend is None, use loky
"joblib_backend": "multiprocessing",
"default_disk_cache": 1, # 0:skip/1:use
"mem_cache_size_limit": 500,
# memory cache expiry time in seconds; only used in 'DatasetURICache' and 'client D.calendar'

View File

@@ -9,16 +9,15 @@ import os
import re
import abc
import copy
import time
import queue
import bisect
import logging
import importlib
import traceback
from typing import List, Union
import numpy as np
import pandas as pd
from multiprocessing import Pool
# For supporting multiprocessing in outer code, joblib is used
from joblib import delayed
from .cache import H
from ..config import C
@@ -29,6 +28,7 @@ from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
from ..utils.resam import resam_calendar
from ..utils.paral import ParallelExt
class ProviderBackendMixin:
@@ -418,16 +418,7 @@ class DatasetProvider(abc.ABC):
"""
raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method")
def _uri(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
disk_cache=1,
**kwargs,
):
def _uri(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, **kwargs):
"""Get task uri, used when generating rabbitmq task in qlib_server
Parameters
@@ -494,51 +485,37 @@ class DatasetProvider(abc.ABC):
"""
normalize_column_names = normalize_cache_fields(column_names)
data = dict()
# One process for one task, so that the memory will be freed quicker.
workers = max(min(C.kernels, len(instruments_d)), 1)
if C.maxtasksperchild is None:
p = Pool(processes=workers)
else:
p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)
# create iterator
if isinstance(instruments_d, dict):
for inst, spans in instruments_d.items():
data[inst] = p.apply_async(
DatasetProvider.expression_calculator,
args=(
inst,
start_time,
end_time,
freq,
normalize_column_names,
spans,
C,
),
)
it = instruments_d.items()
else:
for inst in instruments_d:
data[inst] = p.apply_async(
DatasetProvider.expression_calculator,
args=(
inst,
start_time,
end_time,
freq,
normalize_column_names,
None,
C,
),
)
it = zip(instruments_d, [None] * len(instruments_d))
p.close()
p.join()
inst_l = []
task_l = []
for inst, spans in it:
inst_l.append(inst)
task_l.append(
delayed(DatasetProvider.expression_calculator)(
inst, start_time, end_time, freq, normalize_column_names, spans, C
)
)
data = dict(
zip(
inst_l,
ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(task_l),
)
)
new_data = dict()
for inst in sorted(data.keys()):
if len(data[inst].get()) > 0:
if len(data[inst]) > 0:
# NOTE: Python version >= 3.6; in versions after python3.6, dict will always guarantee the insertion order
new_data[inst] = data[inst].get()
new_data[inst] = data[inst]
if len(new_data) > 0:
data = pd.concat(new_data, names=["instrument"], sort=False)
@@ -755,25 +732,11 @@ class LocalDatasetProvider(DatasetProvider):
start_time = cal[0]
end_time = cal[-1]
workers = max(min(C.kernels, len(instruments_d)), 1)
if C.maxtasksperchild is None:
p = Pool(processes=workers)
else:
p = Pool(processes=workers, maxtasksperchild=C.maxtasksperchild)
for inst in instruments_d:
p.apply_async(
LocalDatasetProvider.cache_walker,
args=(
inst,
start_time,
end_time,
freq,
column_names,
),
)
p.close()
p.join()
ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(
delayed(LocalDatasetProvider.cache_walker)(inst, start_time, end_time, freq, column_names)
for inst in instruments_d
)
@staticmethod
def cache_walker(inst, start_time, end_time, freq, column_names):
@@ -803,12 +766,7 @@ class ClientCalendarProvider(CalendarProvider):
self.conn.send_request(
request_type="calendar",
request_content={
"start_time": str(start_time),
"end_time": str(end_time),
"freq": freq,
"future": future,
},
request_content={"start_time": str(start_time), "end_time": str(end_time), "freq": freq, "future": future},
msg_queue=self.queue,
msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content],
)
@@ -871,16 +829,7 @@ class ClientDatasetProvider(DatasetProvider):
self.conn = conn
self.queue = queue.Queue()
def dataset(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
disk_cache=0,
return_uri=False,
):
def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, return_uri=False):
if Inst.get_inst_type(instruments) == Inst.DICT:
get_module_logger("data").warning(
"Getting features from a dict of instruments is not recommended because the features will not be "
@@ -984,15 +933,7 @@ class BaseProvider:
def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False):
return Inst.list_instruments(instruments, start_time, end_time, freq, as_list)
def features(
self,
instruments,
fields,
start_time=None,
end_time=None,
freq="day",
disk_cache=None,
):
def features(self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=None):
"""
Parameters:
-----------

View File

@@ -1,8 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from joblib import Parallel, delayed
import pandas as pd
from joblib import Parallel, delayed
from joblib._parallel_backends import MultiprocessingBackend
class ParallelExt(Parallel):
    """``joblib.Parallel`` that additionally accepts a ``maxtasksperchild`` keyword.

    The extra keyword is removed from ``kwargs`` before they are handed to
    ``joblib.Parallel`` and, when the selected backend is a
    ``MultiprocessingBackend``, stored in the backend option dictionary.
    For any other backend the value is ignored.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to ``joblib.Parallel``.
    maxtasksperchild : int or None, keyword-only, optional
        Value recorded under the ``"maxtasksperchild"`` key of the backend
        options. Defaults to ``None``.
    """

    def __init__(self, *args, **kwargs):
        # Parallel.__init__ does not know this keyword, so strip it first.
        maxtasksperchild = kwargs.pop("maxtasksperchild", None)
        super(ParallelExt, self).__init__(*args, **kwargs)
        if isinstance(self._backend, MultiprocessingBackend):
            # NOTE(review): these are private joblib attributes. Newer joblib
            # releases appear to keep the backend options in `_backend_kwargs`
            # instead of `_backend_args` -- confirm against the pinned joblib
            # version. Prefer the new name and fall back to the old one so the
            # class works with either layout.
            backend_options = getattr(self, "_backend_kwargs", None)
            if backend_options is None:
                backend_options = self._backend_args
            backend_options["maxtasksperchild"] = maxtasksperchild
def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False):
@@ -31,7 +40,7 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
return df.groupby(axis=axis, level=level).apply(apply_func)
if n_jobs != 1:
dfs = Parallel(n_jobs=n_jobs)(
dfs = ParallelExt(n_jobs=n_jobs)(
delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, axis=axis, level=level)
)
return pd.concat(dfs, axis=axis).sort_index()

View File

@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
import qlib
from qlib.data import D
from qlib.tests import TestAutoData
from multiprocessing import Pool
def get_features(fields):
    """Initialise qlib against the test data and load ``fields`` for csi300.

    qlib is initialised inside this function with ``expression_cache`` and
    ``dataset_cache`` set to ``None`` and the joblib backend set to "loky".
    NOTE(review): presumably it is initialised here because the function is
    executed in a pool worker process -- confirm.

    Parameters
    ----------
    fields : list of str
        Feature expressions passed to ``D.features``.

    Returns
    -------
    The result of ``D.features`` for the "csi300" instruments.
    """
    init_options = {
        "provider_uri": TestAutoData.provider_uri,
        "expression_cache": None,
        "dataset_cache": None,
        "joblib_backend": "loky",
    }
    qlib.init(**init_options)
    universe = D.instruments("csi300")
    return D.features(universe, fields)
class TestGetData(TestAutoData):
    """Checks that feature loading works when called from pool worker processes."""

    # Feature expressions requested by every worker.
    FIELDS = "$open,$close,$high,$low,$volume,$factor,$change".split(",")

    def test_multi_proc(self):
        """
        For testing if it will raise error

        ``get_features`` is submitted ``iter_n`` times to a
        ``multiprocessing.Pool``; any exception raised in a worker is
        re-raised here by ``AsyncResult.get`` and fails the test.
        """
        iter_n = 2
        # The context manager terminates the pool even when a worker fails,
        # so no worker processes are left behind by a failing run.
        with Pool(iter_n) as pool:
            res = [pool.apply_async(get_features, (self.FIELDS,), {}) for _ in range(iter_n)]
            for r in res:
                print(r.get())
            # On success, shut the workers down gracefully before leaving.
            pool.close()
            pool.join()
if __name__ == "__main__":
unittest.main()