From b4efbd53b2f8889b984a5f283e8d62cd3ecf1976 Mon Sep 17 00:00:00 2001 From: zhupr Date: Wed, 16 Jun 2021 22:00:43 +0800 Subject: [PATCH] Fix 'report' compatibility with matplotlib versions --- .../analysis_model_performance.py | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index 1cb14d261..1d444b104 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -3,7 +3,6 @@ import pandas as pd -import plotly.tools as tls import plotly.graph_objs as go import statsmodels.api as sm @@ -80,9 +79,37 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure: :param dist: :return: """ - fig, ax = plt.subplots(figsize=(8, 5)) - _mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax) - return tls.mpl_to_plotly(_mpl_fig) + _plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45") + plt.close(_plt_fig) + qqplot_data = _plt_fig.gca().lines + fig = go.Figure() + + fig.add_trace({ + 'type': 'scatter', + 'x': qqplot_data[0].get_xdata(), + # 'x': [0, 1], + 'y': qqplot_data[0].get_ydata(), + # 'y': [1, 2], + 'mode': 'markers', + 'marker': { + 'color': '#19d3f3' + } + }) + + fig.add_trace({ + 'type': 'scatter', + 'x': qqplot_data[1].get_xdata(), + # 'x': [0, 1], + 'y': qqplot_data[1].get_ydata(), + # 'y': [1, 2], + 'mode': 'lines', + 'line': { + 'color': '#636efa' + } + + }) + del qqplot_data + return fig def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple: