mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Add data analysis feature for report (#918)
* Add data analysis feature for report * better display
This commit is contained in:
@@ -4,8 +4,10 @@ Here is a batch of evaluation functions.
|
||||
The interface should be redesigned carefully in the future.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple
|
||||
from qlib import get_module_logger
|
||||
from qlib.utils.paral import complex_parallel, DelayedDict
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
|
||||
def calc_long_short_prec(
|
||||
@@ -61,32 +63,6 @@ def calc_long_short_prec(
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_long_short_return(
|
||||
pred: pd.Series,
|
||||
label: pd.Series,
|
||||
@@ -127,3 +103,105 @@ def calc_long_short_return(
|
||||
r_short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label.mean())
|
||||
r_avg = group.label.mean()
|
||||
return (r_long - r_short) / 2, r_avg
|
||||
|
||||
|
||||
def pred_autocorr(pred: pd.Series, lag=1, inst_col="instrument", date_col="datetime"):
|
||||
"""pred_autocorr.
|
||||
|
||||
Limitation:
|
||||
- If the datetime is not sequential densely, the correlation will be calulated based on adjacent dates. (some users may expected NaN)
|
||||
|
||||
:param pred: pd.Series with following format
|
||||
instrument datetime
|
||||
SH600000 2016-01-04 -0.000403
|
||||
2016-01-05 -0.000753
|
||||
2016-01-06 -0.021801
|
||||
2016-01-07 -0.065230
|
||||
2016-01-08 -0.062465
|
||||
:type pred: pd.Series
|
||||
:param lag:
|
||||
"""
|
||||
if isinstance(pred, pd.DataFrame):
|
||||
pred = pred.iloc[:, 0]
|
||||
get_module_logger("pred_autocorr").warning("Only the first column in {pred.columns} of `pred` is kept")
|
||||
pred_ustk = pred.sort_index().unstack(inst_col)
|
||||
corr_s = {}
|
||||
for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()):
|
||||
corr_s[idx] = cur.corr(prev)
|
||||
corr_s = pd.Series(corr_s).sort_index()
|
||||
return corr_s
|
||||
|
||||
|
||||
def pred_autocorr_all(pred_dict, n_jobs=-1, **kwargs):
|
||||
"""
|
||||
calculate auto correlation for pred_dict
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict : dict
|
||||
A dict like {<method_name>: <prediction>}
|
||||
kwargs :
|
||||
all these arguments will be passed into pred_autocorr
|
||||
"""
|
||||
ac_dict = {}
|
||||
for k, pred in pred_dict.items():
|
||||
ac_dict[k] = delayed(pred_autocorr)(pred, **kwargs)
|
||||
return complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), ac_dict)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> (pd.Series, pd.Series):
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_all_ic(pred_dict_all, label, date_col="datetime", dropna=False, n_jobs=-1):
|
||||
"""calc_all_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict_all :
|
||||
A dict like {<method_name>: <prediction>}
|
||||
label:
|
||||
A pd.Series of label values
|
||||
|
||||
Returns
|
||||
-------
|
||||
{'Q2+IND_z': {'ic': <ic series like>
|
||||
2016-01-04 -0.057407
|
||||
...
|
||||
2020-05-28 0.183470
|
||||
2020-05-29 0.171393
|
||||
'ric': <rank ic series like>
|
||||
2016-01-04 -0.040888
|
||||
...
|
||||
2020-05-28 0.236665
|
||||
2020-05-29 0.183886
|
||||
}
|
||||
...}
|
||||
"""
|
||||
pred_all_ics = {}
|
||||
for k, pred in pred_dict_all.items():
|
||||
pred_all_ics[k] = DelayedDict(["ic", "ric"], delayed(calc_ic)(pred, label, date_col=date_col, dropna=dropna))
|
||||
pred_all_ics = complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), pred_all_ics)
|
||||
return pred_all_ics
|
||||
|
||||
@@ -74,7 +74,7 @@ class DNNModelPytorch(Model):
|
||||
data_parall=False,
|
||||
scheduler: Optional[Union[Callable]] = "default", # when it is Callable, it accept one argument named optimizer
|
||||
init_model=None,
|
||||
eval_train_metric=True,
|
||||
eval_train_metric=False,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_nn.Net",
|
||||
pt_model_kwargs={
|
||||
"input_dim": 360,
|
||||
@@ -290,7 +290,7 @@ class DNNModelPytorch(Model):
|
||||
)
|
||||
R.log_metrics(train_metric=metric_train, step=step)
|
||||
else:
|
||||
metric_train = -1
|
||||
metric_train = np.nan
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
f"[Step {step}]: train_loss {train_loss:.6f}, valid_loss {loss_val:.6f}, train_metric {metric_train:.6f}, valid_metric {metric_val:.6f}"
|
||||
|
||||
7
qlib/contrib/report/data/__init__.py
Normal file
7
qlib/contrib/report/data/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This module is designed to analysis data
|
||||
|
||||
"""
|
||||
202
qlib/contrib/report/data/ana.py
Normal file
202
qlib/contrib/report/data/ana.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.contrib.report.data.base import FeaAnalyser
|
||||
from qlib.contrib.report.utils import sub_fig_generator
|
||||
from qlib.utils.paral import datetime_groupby_apply
|
||||
from qlib.contrib.eva.alpha import pred_autocorr_all
|
||||
from loguru import logger
|
||||
import seaborn as sns
|
||||
|
||||
DT_COL_NAME = "datetime"
|
||||
|
||||
|
||||
class CombFeaAna(FeaAnalyser):
|
||||
"""
|
||||
Combine the sub feature analysers and plot then in a single graph
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: pd.DataFrame, *fea_ana_cls):
|
||||
if len(fea_ana_cls) <= 1:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
self._fea_ana_l = [fcls(dataset) for fcls in fea_ana_cls]
|
||||
super().__init__(dataset=dataset)
|
||||
|
||||
def skip(self, col):
|
||||
return np.all(list(map(lambda fa: fa.skip(col), self._fea_ana_l)))
|
||||
|
||||
def calc_stat_values(self):
|
||||
"""The statistics of features are finished in the underlying analysers"""
|
||||
|
||||
def plot_all(self, *args, **kwargs):
|
||||
|
||||
ax_gen = iter(sub_fig_generator(row_n=len(self._fea_ana_l), *args, **kwargs))
|
||||
|
||||
for col in self._dataset:
|
||||
if not self.skip(col):
|
||||
axes = next(ax_gen)
|
||||
for fa, ax in zip(self._fea_ana_l, axes):
|
||||
if not fa.skip(col):
|
||||
fa.plot_single(col, ax)
|
||||
ax.set_xlabel("")
|
||||
ax.set_title("")
|
||||
axes[0].set_title(col)
|
||||
|
||||
|
||||
class NumFeaAnalyser(FeaAnalyser):
|
||||
def skip(self, col):
|
||||
is_obj = np.issubdtype(self._dataset[col], np.dtype("O"))
|
||||
if is_obj:
|
||||
logger.info(f"{col} is not numeric and is skipped")
|
||||
return is_obj
|
||||
|
||||
|
||||
class ValueCNT(FeaAnalyser):
|
||||
def __init__(self, dataset: pd.DataFrame, ratio=False):
|
||||
self.ratio = ratio
|
||||
super().__init__(dataset)
|
||||
|
||||
def calc_stat_values(self):
|
||||
self._val_cnt = {}
|
||||
for col, item in self._dataset.items():
|
||||
if not super().skip(col):
|
||||
self._val_cnt[col] = item.groupby(DT_COL_NAME).apply(lambda s: len(s.unique()))
|
||||
self._val_cnt = pd.DataFrame(self._val_cnt)
|
||||
if self.ratio:
|
||||
self._val_cnt = self._val_cnt.div(self._dataset.groupby(DT_COL_NAME).size(), axis=0)
|
||||
|
||||
# TODO: transfer this feature to other analysers
|
||||
ymin, ymax = self._val_cnt.min().min(), self._val_cnt.max().max()
|
||||
self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._val_cnt[col].plot(ax=ax, title=col, ylim=self.ylim)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaDistAna(NumFeaAnalyser):
|
||||
def plot_single(self, col, ax):
|
||||
sns.histplot(self._dataset[col], ax=ax, kde=False, bins=100)
|
||||
ax.set_xlabel("")
|
||||
ax.set_title(col)
|
||||
|
||||
|
||||
class FeaInfAna(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._inf_cnt = {}
|
||||
for col, item in self._dataset.items():
|
||||
if not super().skip(col):
|
||||
self._inf_cnt[col] = item.apply(np.isinf).astype(np.int).groupby(DT_COL_NAME).sum()
|
||||
self._inf_cnt = pd.DataFrame(self._inf_cnt)
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._inf_cnt) or (self._inf_cnt[col].sum() == 0)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._inf_cnt[col].plot(ax=ax, title=col)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaNanAna(FeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME).sum()
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._nan_cnt[col].plot(ax=ax, title=col)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaNanAnaRatio(FeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME).sum()
|
||||
self._total_cnt = self._dataset.groupby(DT_COL_NAME).size()
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
(self._nan_cnt[col] / self._total_cnt).plot(ax=ax, title=col)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaACAna(FeaAnalyser):
|
||||
"""Analysis the auto-correlation of features"""
|
||||
|
||||
def calc_stat_values(self):
|
||||
self._fea_corr = pred_autocorr_all(self._dataset.to_dict("series"))
|
||||
df = pd.DataFrame(self._fea_corr)
|
||||
ymin, ymax = df.min().min(), df.max().max()
|
||||
self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._fea_corr[col].plot(ax=ax, title=col, ylim=self.ylim)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaSkewTurt(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._skew = datetime_groupby_apply(self._dataset, "skew", skip_group=True)
|
||||
self._kurt = datetime_groupby_apply(self._dataset, pd.DataFrame.kurt, skip_group=True)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._skew[col].plot(ax=ax, label="skew")
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("skew")
|
||||
ax.legend()
|
||||
|
||||
right_ax = ax.twinx()
|
||||
|
||||
self._kurt[col].plot(ax=right_ax, label="kurt", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("kurt")
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
|
||||
ax.legend().set_visible(False)
|
||||
right_ax.legend(h1 + h2, l1 + l2)
|
||||
ax.set_title(col)
|
||||
|
||||
|
||||
class FeaMeanStd(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._std = self._dataset.groupby(DT_COL_NAME).std()
|
||||
self._mean = self._dataset.groupby(DT_COL_NAME).mean()
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._mean[col].plot(ax=ax, label="mean")
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("mean")
|
||||
ax.legend()
|
||||
|
||||
right_ax = ax.twinx()
|
||||
|
||||
self._std[col].plot(ax=right_ax, label="std", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("std")
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
|
||||
ax.legend().set_visible(False)
|
||||
right_ax.legend(h1 + h2, l1 + l2)
|
||||
ax.set_title(col)
|
||||
|
||||
|
||||
class RawFeaAna(FeaAnalyser):
|
||||
"""
|
||||
Motivation:
|
||||
- display the values without further analysis
|
||||
"""
|
||||
|
||||
def calc_stat_values(self):
|
||||
ymin, ymax = self._dataset.min().min(), self._dataset.max().max()
|
||||
self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._dataset[col].plot(ax=ax, title=col, ylim=self.ylim)
|
||||
ax.set_xlabel("")
|
||||
36
qlib/contrib/report/data/base.py
Normal file
36
qlib/contrib/report/data/base.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This module is responsible for analysing data
|
||||
|
||||
Assumptions
|
||||
- The analyse each feature individually
|
||||
|
||||
"""
|
||||
import pandas as pd
|
||||
from blocks.utils.log import logt
|
||||
from qlib.contrib.report.utils import sub_fig_generator
|
||||
|
||||
|
||||
class FeaAnalyser:
|
||||
def __init__(self, dataset: pd.DataFrame):
|
||||
self._dataset = dataset
|
||||
with logt("calc_stat_values"):
|
||||
self.calc_stat_values()
|
||||
|
||||
def calc_stat_values(self):
|
||||
pass
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def skip(self, col):
|
||||
return False
|
||||
|
||||
def plot_all(self, *args, **kwargs):
|
||||
|
||||
ax_gen = iter(sub_fig_generator(*args, **kwargs))
|
||||
for col in self._dataset:
|
||||
if not self.skip(col):
|
||||
ax = next(ax_gen)
|
||||
self.plot_single(col, ax)
|
||||
45
qlib/contrib/report/utils.py
Normal file
45
qlib/contrib/report/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
|
||||
"""sub_fig_generator.
|
||||
it will return a generator, each row contains <col_n> sub graph
|
||||
|
||||
FIXME: Known limitation:
|
||||
- The last row will not be plotted automatically, please plot it outside the function
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sub_fs :
|
||||
the figure size of each subgraph in <col_n> * <row_n> subgraphs
|
||||
col_n :
|
||||
the number of subgraph in each row; It will generating a new graph after generating <col_n> of subgraphs.
|
||||
row_n :
|
||||
the number of subgraph in each column
|
||||
wspace :
|
||||
the width of the space for subgraphs in each row
|
||||
hspace :
|
||||
the height of blank space for subgraphs in each column
|
||||
You can try 0.3 if you feel it is too crowded
|
||||
|
||||
Returns
|
||||
-------
|
||||
It will return graphs with the shape of <col_n> each iter (it is squeezed).
|
||||
"""
|
||||
assert col_n > 1
|
||||
|
||||
while True:
|
||||
fig, axes = plt.subplots(
|
||||
row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey
|
||||
)
|
||||
plt.subplots_adjust(wspace=wspace, hspace=hspace)
|
||||
axes = axes.reshape(row_n, col_n)
|
||||
|
||||
for col in range(col_n):
|
||||
res = axes[:, col].squeeze()
|
||||
if res.size == 1:
|
||||
res = res.item()
|
||||
yield res
|
||||
plt.show()
|
||||
@@ -63,7 +63,7 @@ def fetch_df_by_index(
|
||||
Data of the given index.
|
||||
"""
|
||||
# level = None -> use selector directly
|
||||
if level is None:
|
||||
if level is None or isinstance(selector, pd.MultiIndex):
|
||||
return df.loc(axis=0)[selector]
|
||||
# Try to get the right index
|
||||
idx_slc = (selector, slice(None, None))
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from functools import partial
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
from typing import Callable, Text, Union
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
@@ -20,7 +20,9 @@ class ParallelExt(Parallel):
|
||||
self._backend_args["maxtasksperchild"] = maxtasksperchild
|
||||
|
||||
|
||||
def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False):
|
||||
def datetime_groupby_apply(
|
||||
df, apply_func: Union[Callable, Text], axis=0, level="datetime", resample_rule="M", n_jobs=-1, skip_group=False
|
||||
):
|
||||
"""datetime_groupby_apply
|
||||
This function will apply the `apply_func` on the datetime level index.
|
||||
|
||||
@@ -28,8 +30,9 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
|
||||
----------
|
||||
df :
|
||||
DataFrame for processing
|
||||
apply_func :
|
||||
apply_func : Union[Callable, Text]
|
||||
apply_func for processing the data
|
||||
if a string is given, then it is treated as naive pandas function
|
||||
axis :
|
||||
which axis is the datetime level located
|
||||
level :
|
||||
@@ -43,6 +46,8 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
|
||||
"""
|
||||
|
||||
def _naive_group_apply(df):
|
||||
if isinstance(apply_func, str):
|
||||
return getattr(df.groupby(axis=axis, level=level), apply_func)()
|
||||
return df.groupby(axis=axis, level=level).apply(apply_func)
|
||||
|
||||
if n_jobs != 1:
|
||||
@@ -102,3 +107,169 @@ class AsyncCaller:
|
||||
return wrapper
|
||||
|
||||
return decorator_func
|
||||
|
||||
|
||||
# # Outlines: Joblib enhancement
|
||||
# The code are for implementing following workflow
|
||||
# - Construct complex data structure nested with delayed joblib tasks
|
||||
# - For example, {"job": [<delayed_joblib_task>, {"1": <delayed_joblib_task>}]}
|
||||
# - executing all the tasks and replace all the <deplayed_joblib_task> with its return value
|
||||
|
||||
# This will make it easier to convert some existing code to a parallel one
|
||||
|
||||
|
||||
class DelayedTask:
|
||||
def get_delayed_tuple(self):
|
||||
"""get_delayed_tuple.
|
||||
Return the delayed_tuple created by joblib.delayed
|
||||
"""
|
||||
raise NotImplementedError("NotImplemented")
|
||||
|
||||
def set_res(self, res):
|
||||
"""set_res.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
res :
|
||||
the executed result of the delayed tuple
|
||||
"""
|
||||
self.res = res
|
||||
|
||||
def get_replacement(self):
|
||||
"""return the object to replace the delayed task"""
|
||||
raise NotImplementedError("NotImplemented")
|
||||
|
||||
|
||||
class DelayedTuple(DelayedTask):
|
||||
def __init__(self, delayed_tpl):
|
||||
self.delayed_tpl = delayed_tpl
|
||||
self.res = None
|
||||
|
||||
def get_delayed_tuple(self):
|
||||
return self.delayed_tpl
|
||||
|
||||
def get_replacement(self):
|
||||
return self.res
|
||||
|
||||
|
||||
class DelayedDict(DelayedTask):
|
||||
"""DelayedDict.
|
||||
It is designed for following feature:
|
||||
Converting following existing code to parallel
|
||||
- constructing a dict
|
||||
- key can be get instantly
|
||||
- computation of values tasks a lot of time.
|
||||
- AND ALL the values are calculated in a SINGLE function
|
||||
"""
|
||||
|
||||
def __init__(self, key_l, delayed_tpl):
|
||||
self.key_l = key_l
|
||||
self.delayed_tpl = delayed_tpl
|
||||
|
||||
def get_delayed_tuple(self):
|
||||
return self.delayed_tpl
|
||||
|
||||
def get_replacement(self):
|
||||
return dict(zip(self.key_l, self.res))
|
||||
|
||||
|
||||
def is_delayed_tuple(obj) -> bool:
|
||||
"""is_delayed_tuple.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obj : object
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
is `obj` joblib.delayed tuple
|
||||
"""
|
||||
return isinstance(obj, tuple) and len(obj) == 3 and callable(obj[0])
|
||||
|
||||
|
||||
def _replace_and_get_dt(complex_iter):
|
||||
"""_replace_and_get_dt.
|
||||
|
||||
FIXME: this function may cause infinite loop when the complex data-structure contains loop-reference
|
||||
|
||||
Parameters
|
||||
----------
|
||||
complex_iter :
|
||||
complex_iter
|
||||
"""
|
||||
if isinstance(complex_iter, DelayedTask):
|
||||
dt = complex_iter
|
||||
return dt, [dt]
|
||||
elif is_delayed_tuple(complex_iter):
|
||||
dt = DelayedTuple(complex_iter)
|
||||
return dt, [dt]
|
||||
elif isinstance(complex_iter, (list, tuple)):
|
||||
new_ci = []
|
||||
dt_all = []
|
||||
for item in complex_iter:
|
||||
new_item, dt_list = _replace_and_get_dt(item)
|
||||
new_ci.append(new_item)
|
||||
dt_all += dt_list
|
||||
return new_ci, dt_all
|
||||
elif isinstance(complex_iter, dict):
|
||||
new_ci = {}
|
||||
dt_all = []
|
||||
for key, item in complex_iter.items():
|
||||
new_item, dt_list = _replace_and_get_dt(item)
|
||||
new_ci[key] = new_item
|
||||
dt_all += dt_list
|
||||
return new_ci, dt_all
|
||||
else:
|
||||
return complex_iter, []
|
||||
|
||||
|
||||
def _recover_dt(complex_iter):
|
||||
"""_recover_dt.
|
||||
|
||||
replace all the DelayedTask in the `complex_iter` with its `.res` value
|
||||
|
||||
FIXME: this function may cause infinite loop when the complex data-structure contains loop-reference
|
||||
|
||||
Parameters
|
||||
----------
|
||||
complex_iter :
|
||||
complex_iter
|
||||
"""
|
||||
if isinstance(complex_iter, DelayedTask):
|
||||
return complex_iter.get_replacement()
|
||||
elif isinstance(complex_iter, (list, tuple)):
|
||||
return [_recover_dt(item) for item in complex_iter]
|
||||
elif isinstance(complex_iter, dict):
|
||||
return {key: _recover_dt(item) for key, item in complex_iter.items()}
|
||||
else:
|
||||
return complex_iter
|
||||
|
||||
|
||||
def complex_parallel(paral: Parallel, complex_iter):
|
||||
"""complex_parallel.
|
||||
Find all the delayed function created by delayed in complex_iter, run them parallelly and then replace it with the result
|
||||
|
||||
>>> from qlib.utils.paral import complex_parallel
|
||||
>>> from joblib import Parallel, delayed
|
||||
>>> complex_iter = {"a": delayed(sum)([1,2,3]), "b": [1, 2, delayed(sum)([10, 1])]}
|
||||
>>> complex_parallel(Parallel(), complex_iter)
|
||||
{'a': 6, 'b': [1, 2, 11]}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
paral : Parallel
|
||||
paral
|
||||
complex_iter :
|
||||
NOTE: only list, tuple and dict will be explored!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
complex_iter whose delayed joblib tasks are replaced with its execution results.
|
||||
"""
|
||||
|
||||
complex_iter, dt_all = _replace_and_get_dt(complex_iter)
|
||||
for res, dt in zip(paral(dt.get_delayed_tuple() for dt in dt_all), dt_all):
|
||||
dt.set_res(res)
|
||||
complex_iter = _recover_dt(complex_iter)
|
||||
return complex_iter
|
||||
|
||||
Reference in New Issue
Block a user