diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index 714cfdd9c..6d108cabf 100644 --- a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -100,13 +100,13 @@ def _report_figure(df: pd.DataFrame) -> [list, tuple]: ("cum_ex_return_wo_cost_mdd", dict(row=7, col=1, graph_kwargs=_temp_fill_args)), ] - _subplot_layout = dict( - xaxis=dict(showline=True, type="category", tickangle=45), - yaxis=dict(zeroline=True, showline=True, showticklabels=True), - ) - for i in range(2, 8): + _subplot_layout = dict() + for i in range(1, 8): # yaxis _subplot_layout.update({"yaxis{}".format(i): dict(zeroline=True, showline=True, showticklabels=True)}) + _show_line = i == 7 + _subplot_layout.update({"xaxis{}".format(i): dict(showline=_show_line, type="category", tickangle=45)}) + _layout_style = dict( height=1200, title=" ", diff --git a/qlib/contrib/report/analysis_position/risk_analysis.py b/qlib/contrib/report/analysis_position/risk_analysis.py index 89650c39e..124a9b3b0 100644 --- a/qlib/contrib/report/analysis_position/risk_analysis.py +++ b/qlib/contrib/report/analysis_position/risk_analysis.py @@ -116,7 +116,11 @@ def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]: if analysis_df is None: return [] - _figure = SubplotsGraph(_get_all_risk_analysis(analysis_df), kind_map=dict(kind="BarGraph", kwargs={})).figure + _figure = SubplotsGraph( + _get_all_risk_analysis(analysis_df), + kind_map=dict(kind="BarGraph", kwargs={}), + subplots_kwargs={"rows": 4, "cols": 1}, + ).figure return (_figure,) diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index 07ed94f90..15cc5fd0e 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -125,7 +125,10 @@ class BaseGraph(object): :return: """ - return go.Figure(data=self.data, layout=self._get_layout()) + _figure = go.Figure(data=self.data, layout=self._get_layout()) + # NOTE: using default 3.x theme + _figure["layout"].update(template=None) + return _figure class ScatterGraph(BaseGraph): @@ -363,7 +366,8 @@ class SubplotsGraph(object): for k, v in self._sub_graph_layout.items(): self._figure["layout"][k].update(v) - self._figure["layout"].update(self._layout) + # NOTE: using default 3.x theme + self._figure["layout"].update(self._layout, template=None) @property def figure(self): diff --git a/setup.py b/setup.py index d08e378cb..7c2688666 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ REQUIRED = [ "fire>=0.2.1", "statsmodels", "xlrd>=1.0.0", - "plotly==3.5.0", + "plotly==4.12.0", "matplotlib==3.1.3", "tables>=3.6.1", "pyyaml>=5.3.1",