diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index 1cb14d261..1d444b104 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -3,7 +3,6 @@ import pandas as pd -import plotly.tools as tls import plotly.graph_objs as go import statsmodels.api as sm @@ -80,9 +79,37 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure: :param dist: :return: """ - fig, ax = plt.subplots(figsize=(8, 5)) - _mpl_fig = sm.qqplot(data.dropna(), dist, fit=True, line="45", ax=ax) - return tls.mpl_to_plotly(_mpl_fig) + _plt_fig = sm.qqplot(data.dropna(), dist=dist, fit=True, line="45") + plt.close(_plt_fig) + qqplot_data = _plt_fig.gca().lines + fig = go.Figure() + + fig.add_trace({ + 'type': 'scatter', + 'x': qqplot_data[0].get_xdata(), + # 'x': [0, 1], + 'y': qqplot_data[0].get_ydata(), + # 'y': [1, 2], + 'mode': 'markers', + 'marker': { + 'color': '#19d3f3' + } + }) + + fig.add_trace({ + 'type': 'scatter', + 'x': qqplot_data[1].get_xdata(), + # 'x': [0, 1], + 'y': qqplot_data[1].get_ydata(), + # 'y': [1, 2], + 'mode': 'lines', + 'line': { + 'color': '#636efa' + } + + }) + del qqplot_data + return fig def _pred_ic(pred_label: pd.DataFrame = None, rank: bool = False, **kwargs) -> tuple: