mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Plot enhancement (#1390)
* horizontally put the bar figures * 1) use rangebreaks to handle gaps in datetime axis instead of make them string; 2) allow simultaneously plot rankic in ic_figure * pylint improvement * fix black lint * better axis formatting * default not show gaps * resolve doc built error * fix pylint * Update qlib/contrib/report/analysis_model/analysis_model_performance.py More detailed description Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> * Update qlib/contrib/report/analysis_model/analysis_model_performance.py for Python backward compatibility Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> * add doc string * fix black * 1) limit numpy version as numba support for 1.24+ has not been released; 2) no need to use custom numba version for pytest. Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -10,7 +11,11 @@ import matplotlib.pyplot as plt
|
||||
|
||||
from scipy import stats
|
||||
|
||||
from typing import Sequence
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from ..graph import ScatterGraph, SubplotsGraph, BarGraph, HeatmapGraph
|
||||
from ..utils import guess_plotly_rangebreaks
|
||||
|
||||
|
||||
def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs) -> tuple:
|
||||
@@ -48,12 +53,13 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
|
||||
t_df["long-average"] = t_df["Group1"] - pred_label.groupby(level="datetime")["label"].mean()
|
||||
|
||||
t_df = t_df.dropna(how="all") # for days which does not contain label
|
||||
# FIXME: support HIGH-FREQ
|
||||
t_df.index = t_df.index.strftime("%Y-%m-%d")
|
||||
# Cumulative Return By Group
|
||||
group_scatter_figure = ScatterGraph(
|
||||
t_df.cumsum(),
|
||||
layout=dict(title="Cumulative Return", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Cumulative Return",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(t_df.index))),
|
||||
),
|
||||
).figure
|
||||
|
||||
t_df = t_df.loc[:, ["long-short", "long-average"]]
|
||||
@@ -110,22 +116,36 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure:
|
||||
return fig
|
||||
|
||||
|
||||
def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple:
|
||||
def _pred_ic(
|
||||
pred_label: pd.DataFrame = None, methods: Sequence[Literal["IC", "Rank IC"]] = ("IC", "Rank IC"), **kwargs
|
||||
) -> tuple:
|
||||
"""
|
||||
|
||||
:param pred_label:
|
||||
:param rank:
|
||||
:param pred_label: pd.DataFrame
|
||||
must contain one column of realized return with name `label` and one column of predicted score names `score`.
|
||||
:param methods: Sequence[Literal["IC", "Rank IC"]]
|
||||
IC series to plot.
|
||||
IC is sectional pearson correlation between label and score
|
||||
Rank IC is the spearman correlation between label and score
|
||||
For the Monthly IC, IC histogram, IC Q-Q plot. Only the first type of IC will be plotted.
|
||||
:return:
|
||||
"""
|
||||
if rank:
|
||||
ic = pred_label.groupby(level="datetime").apply(
|
||||
lambda x: x["label"].rank(pct=True).corr(x["score"].rank(pct=True))
|
||||
)
|
||||
else:
|
||||
ic = pred_label.groupby(level="datetime").apply(lambda x: x["label"].corr(x["score"]))
|
||||
_methods_mapping = {"IC": "pearson", "Rank IC": "spearman"}
|
||||
|
||||
_index = ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
|
||||
_monthly_ic = ic.groupby(_index).mean()
|
||||
def _corr_series(x, method):
|
||||
return x["label"].corr(x["score"], method=method)
|
||||
|
||||
ic_df = pd.concat(
|
||||
[
|
||||
pred_label.groupby(level="datetime").apply(partial(_corr_series, method=_methods_mapping[m])).rename(m)
|
||||
for m in methods
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
_ic = ic_df.iloc(axis=1)[0]
|
||||
|
||||
_index = _ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6)
|
||||
_monthly_ic = _ic.groupby(_index).mean()
|
||||
_monthly_ic.index = pd.MultiIndex.from_arrays(
|
||||
[_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)],
|
||||
names=["year", "month"],
|
||||
@@ -148,27 +168,27 @@ def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> t
|
||||
|
||||
_monthly_ic = _monthly_ic.reindex(fill_index)
|
||||
|
||||
_ic_df = ic.to_frame("ic")
|
||||
ic_bar_figure = ic_figure(_ic_df, kwargs.get("show_nature_day", True))
|
||||
ic_bar_figure = ic_figure(ic_df, kwargs.get("show_nature_day", False))
|
||||
|
||||
ic_heatmap_figure = HeatmapGraph(
|
||||
_monthly_ic.unstack(),
|
||||
layout=dict(title="Monthly IC", yaxis=dict(tickformat=",d")),
|
||||
layout=dict(title="Monthly IC", xaxis=dict(dtick=1), yaxis=dict(tickformat="04d", dtick=1)),
|
||||
graph_kwargs=dict(xtype="array", ytype="array"),
|
||||
).figure
|
||||
|
||||
dist = stats.norm
|
||||
_qqplot_fig = _plot_qq(ic, dist)
|
||||
_qqplot_fig = _plot_qq(_ic, dist)
|
||||
|
||||
if isinstance(dist, stats.norm.__class__):
|
||||
dist_name = "Normal"
|
||||
else:
|
||||
dist_name = "Unknown"
|
||||
|
||||
_ic_df = _ic.to_frame("IC")
|
||||
_bin_size = ((_ic_df.max() - _ic_df.min()) / 20).min()
|
||||
_sub_graph_data = [
|
||||
(
|
||||
"ic",
|
||||
"IC",
|
||||
dict(
|
||||
row=1,
|
||||
col=1,
|
||||
@@ -202,12 +222,13 @@ def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple:
|
||||
pred = pred_label.copy()
|
||||
pred["score_last"] = pred.groupby(level="instrument")["score"].shift(lag)
|
||||
ac = pred.groupby(level="datetime").apply(lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True)))
|
||||
# FIXME: support HIGH-FREQ
|
||||
_df = ac.to_frame("value")
|
||||
_df.index = _df.index.strftime("%Y-%m-%d")
|
||||
ac_figure = ScatterGraph(
|
||||
_df,
|
||||
layout=dict(title="Auto Correlation", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Auto Correlation",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_df.index))),
|
||||
),
|
||||
).figure
|
||||
return (ac_figure,)
|
||||
|
||||
@@ -233,32 +254,33 @@ def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple:
|
||||
"Bottom": bottom,
|
||||
}
|
||||
)
|
||||
# FIXME: support HIGH-FREQ
|
||||
r_df.index = r_df.index.strftime("%Y-%m-%d")
|
||||
turnover_figure = ScatterGraph(
|
||||
r_df,
|
||||
layout=dict(title="Top-Bottom Turnover", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Top-Bottom Turnover",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(r_df.index))),
|
||||
),
|
||||
).figure
|
||||
return (turnover_figure,)
|
||||
|
||||
|
||||
def ic_figure(ic_df: pd.DataFrame, show_nature_day=True, **kwargs) -> go.Figure:
|
||||
"""IC figure
|
||||
r"""IC figure
|
||||
|
||||
:param ic_df: ic DataFrame
|
||||
:param show_nature_day: whether to display the abscissa of non-trading day
|
||||
:param \*\*kwargs: contains some parameters to control plot style in plotly. Currently, supports
|
||||
- `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
|
||||
:return: plotly.graph_objs.Figure
|
||||
"""
|
||||
if show_nature_day:
|
||||
date_index = pd.date_range(ic_df.index.min(), ic_df.index.max())
|
||||
ic_df = ic_df.reindex(date_index)
|
||||
# FIXME: support HIGH-FREQ
|
||||
ic_df.index = ic_df.index.strftime("%Y-%m-%d")
|
||||
ic_bar_figure = BarGraph(
|
||||
ic_df,
|
||||
layout=dict(
|
||||
title="Information Coefficient (IC)",
|
||||
xaxis=dict(type="category", tickangle=45),
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(ic_df.index))),
|
||||
),
|
||||
).figure
|
||||
return ic_bar_figure
|
||||
@@ -272,9 +294,10 @@ def model_performance_graph(
|
||||
rank=False,
|
||||
graph_names: list = ["group_return", "pred_ic", "pred_autocorr"],
|
||||
show_notebook: bool = True,
|
||||
show_nature_day=True,
|
||||
show_nature_day: bool = False,
|
||||
**kwargs,
|
||||
) -> [list, tuple]:
|
||||
"""Model performance
|
||||
r"""Model performance
|
||||
|
||||
:param pred_label: index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[score, label]**.
|
||||
It is usually same as the label of model training(e.g. "Ref($close, -2)/Ref($close, -1) - 1").
|
||||
@@ -297,17 +320,14 @@ def model_performance_graph(
|
||||
:param graph_names: graph names; default ['cumulative_return', 'pred_ic', 'pred_autocorr', 'pred_turnover'].
|
||||
:param show_notebook: whether to display graphics in notebook, the default is `True`.
|
||||
:param show_nature_day: whether to display the abscissa of non-trading day.
|
||||
:param \*\*kwargs: contains some parameters to control plot style in plotly. Currently, supports
|
||||
- `rangebreaks`: https://plotly.com/python/time-series/#Hiding-Weekends-and-Holidays
|
||||
:return: if show_notebook is True, display in notebook; else return `plotly.graph_objs.Figure` list.
|
||||
"""
|
||||
figure_list = []
|
||||
for graph_name in graph_names:
|
||||
fun_res = eval(f"_{graph_name}")(
|
||||
pred_label=pred_label,
|
||||
lag=lag,
|
||||
N=N,
|
||||
reverse=reverse,
|
||||
rank=rank,
|
||||
show_nature_day=show_nature_day,
|
||||
pred_label=pred_label, lag=lag, N=N, reverse=reverse, rank=rank, show_nature_day=show_nature_day, **kwargs
|
||||
)
|
||||
figure_list += fun_res
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]:
|
||||
_figure = SubplotsGraph(
|
||||
_get_all_risk_analysis(analysis_df),
|
||||
kind_map=dict(kind="BarGraph", kwargs={}),
|
||||
subplots_kwargs={"rows": 4, "cols": 1},
|
||||
subplots_kwargs={"rows": 1, "cols": 4},
|
||||
).figure
|
||||
return (_figure,)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import pandas as pd
|
||||
|
||||
from ..graph import ScatterGraph
|
||||
from ..utils import guess_plotly_rangebreaks
|
||||
|
||||
|
||||
def _get_score_ic(pred_label: pd.DataFrame):
|
||||
@@ -19,7 +20,7 @@ def _get_score_ic(pred_label: pd.DataFrame):
|
||||
return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic})
|
||||
|
||||
|
||||
def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [list, tuple]:
|
||||
def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True, **kwargs) -> [list, tuple]:
|
||||
"""score IC
|
||||
|
||||
Example:
|
||||
@@ -53,11 +54,13 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True) -> [lis
|
||||
:return: if show_notebook is True, display in notebook; else return **plotly.graph_objs.Figure** list.
|
||||
"""
|
||||
_ic_df = _get_score_ic(pred_label)
|
||||
# FIXME: support HIGH-FREQ
|
||||
_ic_df.index = _ic_df.index.strftime("%Y-%m-%d")
|
||||
|
||||
_figure = ScatterGraph(
|
||||
_ic_df,
|
||||
layout=dict(title="Score IC", xaxis=dict(type="category", tickangle=45)),
|
||||
layout=dict(
|
||||
title="Score IC",
|
||||
xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_ic_df.index))),
|
||||
),
|
||||
graph_kwargs={"mode": "lines+markers"},
|
||||
).figure
|
||||
if show_notebook:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
|
||||
@@ -43,3 +44,31 @@ def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None
|
||||
res = res.item()
|
||||
yield res
|
||||
plt.show()
|
||||
|
||||
|
||||
def guess_plotly_rangebreaks(dt_index: pd.DatetimeIndex):
|
||||
"""
|
||||
This function `guesses` the rangebreaks required to remove gaps in datetime index.
|
||||
It basically calculates the difference between a `continuous` datetime index and index given.
|
||||
|
||||
For more details on `rangebreaks` params in plotly, see
|
||||
https://plotly.com/python/reference/layout/xaxis/#layout-xaxis-rangebreaks
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dt_index: pd.DatetimeIndex
|
||||
The datetimes of the data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
the `rangebreaks` to be passed into plotly axis.
|
||||
|
||||
"""
|
||||
dt_idx = dt_index.sort_values()
|
||||
gaps = dt_idx[1:] - dt_idx[:-1]
|
||||
min_gap = gaps.min()
|
||||
gaps_to_break = {}
|
||||
for gap, d in zip(gaps, dt_idx[:-1]):
|
||||
if gap > min_gap:
|
||||
gaps_to_break.setdefault(gap - min_gap, []).append(d + min_gap)
|
||||
return [dict(values=v, dvalue=int(k.total_seconds() * 1000)) for k, v in gaps_to_break.items()]
|
||||
|
||||
2
setup.py
2
setup.py
@@ -44,7 +44,7 @@ if not _CYTHON_INSTALLED:
|
||||
# What packages are required for this module to be executed?
|
||||
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
|
||||
REQUIRED = [
|
||||
"numpy>=1.12.0",
|
||||
"numpy>=1.12.0, <1.24",
|
||||
"pandas>=0.25.1",
|
||||
"scipy>=1.0.0",
|
||||
"requests>=2.18.0",
|
||||
|
||||
Reference in New Issue
Block a user