• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ricequant / rqalpha / 26139152694

20 May 2026 03:19AM UTC coverage: 68.183% (+0.2%) from 68.024%
26139152694

push

github

web-flow
Merge pull request #1007 from ricequant/develop

Develop

197 of 272 new or added lines in 15 files covered. (72.43%)

169 existing lines in 7 files now uncovered.

7642 of 11208 relevant lines covered (68.18%)

5.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

28.8
/rqalpha/mod/rqalpha_mod_sys_analyser/plot/plot.py
1
# -*- coding: utf-8 -*-
2
# 版权所有 2021 深圳米筐科技有限公司(下称“米筐科技”)
3
#
4
# 除非遵守当前许可,否则不得使用本软件。
5
#
6
#     * 非商业用途(非商业用途指个人出于非商业目的使用本软件,或者高校、研究所等非营利机构出于教育、科研等目的使用本软件):
7
#         遵守 Apache License 2.0(下称“Apache 2.0 许可”),
8
#         您可以在以下位置获得 Apache 2.0 许可的副本:http://www.apache.org/licenses/LICENSE-2.0。
9
#         除非法律有要求或以书面形式达成协议,否则本软件分发时需保持当前许可“原样”不变,且不得附加任何条件。
10
#
11
#     * 商业用途(商业用途指个人出于任何商业目的使用本软件,或者法人或其他组织出于任何目的使用本软件):
12
#         未经米筐科技授权,任何个人不得出于任何商业目的使用本软件(包括但不限于向第三方提供、销售、出租、出借、转让本软件、
13
#         本软件的衍生产品、引用或借鉴了本软件功能或源代码的产品或服务),任何法人或其他组织不得出于任何目的使用本软件,
14
#         否则米筐科技有权追究相应的知识产权侵权责任。
15
#         在此前提下,对本软件的使用同样需要遵守 Apache 2.0 许可,Apache 2.0 许可与本许可冲突之处,以本许可为准。
16
#         详细的授权流程,请联系 public@ricequant.com 获取。
17

18
import os
8✔
19
from typing import List, Mapping, Tuple, Sequence, Optional
8✔
20
from collections import ChainMap
8✔
21

22
import pandas as pd
8✔
23
from matplotlib.axes import Axes
8✔
24
from matplotlib.figure import Figure
8✔
25
from matplotlib import gridspec, ticker, image as mpimg, pyplot
8✔
26

27
import rqalpha
8✔
28
from rqalpha.const import POSITION_EFFECT
8✔
29
from rqalpha.utils.logger import system_log
8✔
30
from .utils import IndicatorInfo, LineInfo, max_dd as _max_dd, SpotInfo, max_ddd as _max_ddd
8✔
31
from .utils import weekly_returns, trading_dates_index
8✔
32
from .consts import PlotTemplate, DefaultPlot
8✔
33
from .consts import IMG_WIDTH, INDICATOR_AREA_HEIGHT, PLOT_AREA_HEIGHT, USER_PLOT_AREA_HEIGHT, PLOT_TITLE_HEIGHT
8✔
34
from .consts import LABEL_FONT_SIZE, BLACK, SUPPORT_CHINESE, TITLE_FONT_SIZE
8✔
35
from .consts import MAX_DD, MAX_DDD, OPEN_POINT, CLOSE_POINT
8✔
36
from .consts import LINE_BENCHMARK, LINE_STRATEGY, LINE_WEEKLY_BENCHMARK, LINE_WEEKLY, LINE_EXCESS
8✔
37

38

39
class SubPlot:
    """Base class for one vertically stacked region of the result figure.

    Concrete sub plots provide ``height`` and override :meth:`plot`.
    """

    # Number of grid rows this sub plot occupies; ``_plot`` sums these to get
    # the figure height and uses them to slice the GridSpec rows.
    height: int
    # Stop value of the column slice handed to the GridSpec (``:right_pad``);
    # ``None`` means the sub plot spans every column.
    right_pad: Optional[int] = None

    def plot(self, ax: Axes):
        """Draw this sub plot onto ``ax``; subclasses must override this."""
        raise NotImplementedError()
×
45

46

47
class IndicatorArea(SubPlot):
    """Sub plot that renders indicator labels and their formatted values as text.

    The indicators are given as rows (a list of lists); every row is laid out
    on the same column grid, whose column count is that of the widest row.
    When a strategy name is supplied, a :class:`TitlePlot` is drawn on the same
    axes after the indicators.
    """

    height: int = INDICATOR_AREA_HEIGHT
    right_pad = -1
    # Horizontal margin left free on each side of the indicator grid.
    X_PADDING = 0.02

    def __init__(
            self, indicators: List[List[IndicatorInfo]], indicator_values: Mapping[str, float],
            plot_template: PlotTemplate, strategy_name=None
    ):
        """Store the layout and the data to render.

        :param indicators: rows of indicators, listed from top to bottom.
        :param indicator_values: mapping looked up with each indicator's ``key``.
        :param plot_template: provides ``INDICATOR_VALUE_HEIGHT`` and
            ``INDICATOR_LABEL_HEIGHT`` used for vertical spacing.
        :param strategy_name: optional title text; nothing is drawn when falsy.
        """
        self._indicators = indicators
        self._values = indicator_values
        self._template = plot_template
        self._strategy_name = strategy_name

    def _iter_layout(self, indicators):
        """Yield ``(x, indicator)`` for each indicator of one row.

        ``x`` is the horizontal centre of the indicator's column. Column width
        is derived from the widest row of the whole grid, so indicators in
        different rows line up in shared columns.
        """
        available_width = 1 - 2 * self.X_PADDING
        # default=0 keeps max() from raising ValueError when there are no rows.
        column_count = max((len(row) for row in self._indicators), default=0)
        column_width = available_width / column_count if column_count else 0
        for index, indicator in enumerate(indicators):
            yield self.X_PADDING + column_width * (index + 0.5), indicator

    @staticmethod
    def _label_font_size(label: str) -> int:
        """Return a font size for ``label``, smaller for longer labels."""
        if len(label) >= 30:
            return 7
        if len(label) >= 22:
            return 8
        if len(label) >= 16:
            return 9
        return LABEL_FONT_SIZE

    @staticmethod
    def _value_font_size(value: str, base_size: int) -> int:
        """Return a font size for ``value``, capped according to its longest line.

        ``base_size`` is never increased; it is only capped at 5 or 6 when the
        longest line has at least 24 or 18 characters respectively.
        """
        longest_line = max((len(line) for line in value.splitlines()), default=0)
        if longest_line >= 24:
            return min(base_size, 5)
        if longest_line >= 18:
            return min(base_size, 6)
        return base_size

    def plot(self, ax: Axes):
        """Write every indicator label and value onto ``ax`` and hide the axis."""
        ax.axis("off")
        for lineno, indicators in enumerate(self._indicators[::-1]):  # lineno: row number counted from the bottom up
            # The vertical positions depend only on the row, so compute them once per row.
            y_value = lineno * (self._template.INDICATOR_VALUE_HEIGHT + self._template.INDICATOR_LABEL_HEIGHT)
            y_label = y_value + self._template.INDICATOR_LABEL_HEIGHT
            for x, i in self._iter_layout(indicators):
                try:
                    value = i.formatter.format(self._values[i.key])
                except KeyError:
                    # Indicators without a value are still shown, as "nan".
                    value = "nan"
                ax.text(x, y_label, i.label, color=i.color, fontsize=self._label_font_size(i.label), ha="center")
                ax.text(
                    x, y_value, value, color=BLACK,
                    fontsize=self._value_font_size(value, i.value_font_size), ha="center"
                )
        if self._strategy_name:
            p = TitlePlot(self._strategy_name, len(self._indicators), self._template)
            p.plot(ax)
×
102
        
103

104
class ReturnPlot(SubPlot):
    """Sub plot that draws return curves plus marker spots on the strategy curve."""

    height: int = PLOT_AREA_HEIGHT

    def __init__(
            self,
            returns,
            lines: List[Tuple[pd.Series, LineInfo]],
            spots_on_returns: List[Tuple[Sequence[int], SpotInfo]]
    ):
        """Store the data to draw.

        :param returns: series the marker spots are placed on; it is indexed
            by integer position when drawing the spots.
        :param lines: ``(series, style)`` pairs, each drawn as one line.
        :param spots_on_returns: ``(positions, style)`` pairs, where
            ``positions`` are integer positions into ``returns``.
        """
        self._returns = returns
        self._lines = lines
        self._spots_on_returns = spots_on_returns

    @classmethod
    def _plot_line(cls, ax, returns, info: LineInfo):
        """Draw one line styled by ``info``; a ``None`` series is skipped."""
        if returns is not None:
            ax.plot(returns, label=info.label, alpha=info.alpha, linewidth=info.linewidth, color=info.color)

    def _plot_spots_on_returns(self, ax, positions: Sequence[int], info: SpotInfo):
        """Draw markers at the given integer ``positions`` of the returns series."""
        ax.plot(
            # ``positions`` are positional, so the values are taken with .iloc to
            # match ``index[positions]``; plain ``series[ints]`` relies on the
            # positional fallback that pandas has deprecated.
            self._returns.index[positions], self._returns.iloc[positions],
            info.marker, color=info.color, markersize=info.markersize, alpha=info.alpha, label=info.label
        )

    def plot(self, ax: Axes):
        """Draw grid, lines, spots and legend on ``ax`` and label the y axis in percent."""
        ax.get_xaxis().set_minor_locator(ticker.AutoMinorLocator())
        ax.get_yaxis().set_minor_locator(ticker.AutoMinorLocator())
        ax.grid(visible=True, which='minor', linewidth=.2)
        ax.grid(visible=True, which='major', linewidth=1)
        ax.patch.set_alpha(0.6)

        # plot lines
        for returns, info in self._lines:
            self._plot_line(ax, returns, info)
        # plot MaxDD/MaxDDD
        for positions, info in self._spots_on_returns:
            self._plot_spots_on_returns(ax, positions, info)

        # place legend on the axes that was passed in, not on whichever axes is current
        ax.legend(loc="best").get_frame().set_alpha(0.5)
        # manipulate axis
        # Pin the tick positions first so the labels set below stay attached to them.
        ax.set_yticks(ax.get_yticks())  # make matplotlib happy
        ax.set_yticklabels(['{:3.2f}%'.format(x * 100) for x in ax.get_yticks()])
×
147

148

149
class UserPlot(SubPlot):
    """Sub plot that draws every column of a user-supplied data frame as a line."""

    height: int = USER_PLOT_AREA_HEIGHT

    def __init__(self, plots_df):
        """:param plots_df: data frame whose columns are plotted, labelled by column name."""
        self._df = plots_df

    def plot(self, ax: Axes):
        """Draw one line per column on ``ax`` and add a semi-transparent legend."""
        ax.patch.set_alpha(0.6)
        for column in self._df.columns:
            ax.plot(self._df[column], label=column)
        # Attach the legend to the axes that was passed in rather than to
        # whichever axes pyplot currently considers active.
        ax.legend(loc="best").get_frame().set_alpha(0.5)
×
160

161

162
class TitlePlot(SubPlot):
    """Sub plot that writes the strategy name above the indicator rows."""

    height: int = PLOT_TITLE_HEIGHT

    def __init__(self, strategy_name, indicator_area_rows, plot_template: PlotTemplate):
        """Remember the title text, the indicator row count and the template.

        :param strategy_name: text drawn as the title.
        :param indicator_area_rows: number of indicator rows the title sits above.
        :param plot_template: provides ``INDICATOR_LABEL_HEIGHT`` and
            ``INDICATOR_VALUE_HEIGHT`` used to compute the vertical offset.
        """
        self._strategy_name = strategy_name
        self._indicator_area_rows = indicator_area_rows
        self._template = plot_template

    def plot(self, ax: Axes):
        """Draw the title on ``ax``, just above the topmost indicator row."""
        template = self._template
        # The title is centered on the whole figure rather than on this subplot.
        title_x = 0.57
        title_y = (
            (template.INDICATOR_LABEL_HEIGHT + template.INDICATOR_VALUE_HEIGHT) * self._indicator_area_rows + 0.1
        )
        ax.text(
            title_x, title_y, self._strategy_name,
            ha='center', va='bottom', color=BLACK, fontsize=TITLE_FONT_SIZE,
        )
×
174

175
class WaterMark:
    """Loads the Ricequant logo and draws it as a translucent image on a figure."""

    def __init__(self, img_width, img_height, strategy_name):
        """Read the logo and derive the figure dpi from its pixel width.

        :param img_width: figure width; the dpi is the logo's pixel width
            divided by this value, scaled by 1.1.
        :param img_height: figure height; stored but not used by :meth:`plot`.
        :param strategy_name: accepted for the caller's convenience; unused here.
        """
        package_dir = os.path.dirname(os.path.realpath(rqalpha.__file__))
        logo_file = os.path.join(package_dir, "resource", 'ricequant-logo.png')
        self.img_width = img_width
        self.img_height = img_height
        self.logo_img = mpimg.imread(logo_file)
        self.dpi = self.logo_img.shape[1] / img_width * 1.1

    def plot(self, fig: Figure):
        """Place the logo on ``fig`` with 40% opacity.

        The logo is centered horizontally over the full figure width and
        vertically within a band ``PLOT_AREA_HEIGHT`` tall.
        NOTE(review): presumably the band starts at the figure's bottom edge
        (figimage offsets) -- confirm against the matplotlib docs.
        """
        logo_height, logo_width = self.logo_img.shape[0], self.logo_img.shape[1]
        x_offset = (self.img_width * self.dpi - logo_width) / 2
        y_offset = (PLOT_AREA_HEIGHT * self.dpi - logo_height) / 2
        fig.figimage(self.logo_img, xo=x_offset, yo=y_offset, alpha=0.4)
192

193

194
def _compact_index_range(index_range):
    """Format a date range as a short ``"start~end, Nd"`` string.

    :param index_range: object exposing ``start_date`` and ``end_date``; both
        must support ``.year``, ``.strftime`` and subtraction yielding an
        object with ``.days``.
    :return: ``"MM-DD~MM-DD, Nd"`` when both dates fall in the same year,
        otherwise ``"YYYY-MM-DD~YYYY-MM-DD, Nd"``, where ``N`` is the number
        of days between the two dates.
    """
    start = index_range.start_date
    end = index_range.end_date
    # Ranges within one year drop the year to stay short. The cross-year form
    # is formatted explicitly rather than via str() so that datetime-like
    # values do not leak a time component into the text.
    date_format = "%m-%d" if start.year == end.year else "%Y-%m-%d"
    return "{}~{}, {}d".format(start.strftime(date_format), end.strftime(date_format), (end - start).days)
×
202

203

204
def _plot(title: str, sub_plots: List[SubPlot], strategy_name):
    """Build the figure by stacking ``sub_plots`` from top to bottom.

    :param title: passed to ``pyplot.figure`` as the figure identifier.
    :param sub_plots: drawn in order; each takes ``height`` grid rows and the
        columns up to its ``right_pad`` out of an 8-column grid.
    :param strategy_name: forwarded to :class:`WaterMark`.
    :return: the created figure.
    """
    total_rows = sum(sub_plot.height for sub_plot in sub_plots)
    mark = WaterMark(IMG_WIDTH, total_rows, strategy_name)
    # The figure takes its dpi from the watermark.
    figure = pyplot.figure(title, figsize=(IMG_WIDTH, total_rows), dpi=mark.dpi, clear=True)
    mark.plot(figure)

    grid = gridspec.GridSpec(total_rows, 8, figure=figure)
    top = 0
    for sub_plot in sub_plots:
        bottom = top + sub_plot.height
        sub_plot.plot(pyplot.subplot(grid[top:bottom, :sub_plot.right_pad]))
        top = bottom

    pyplot.tight_layout()
    return figure
×
218

219

220
def plot_result(
        result_dict, show=True, save=None, weekly_indicators: bool = False, open_close_points: bool = False,
        plot_template_cls=DefaultPlot, strategy_name=None
):
    """Render the result figure and optionally save and/or show it.

    :param result_dict: mapping that must contain ``"summary"`` and
        ``"portfolio"``. ``"benchmark_portfolio"`` and ``"plots"`` are used
        when present; ``"trades"`` is read only when ``open_close_points`` is
        true.
    :param show: call ``pyplot.show()`` after drawing.
    :param save: file path to save the image to; when it is an existing
        directory the file is named ``<summary["strategy_name"]>.png`` inside it.
    :param weekly_indicators: also draw weekly return lines and append the
        template's weekly indicators.
    :param open_close_points: mark open/close trade dates on the return curve
        (skipped when the trades frame is empty).
    :param plot_template_cls: template class instantiated with the strategy's
        and the benchmark's unit net value (``None`` without a benchmark).
    :param strategy_name: optional title; when given, the indicator area is
        enlarged by ``PLOT_TITLE_HEIGHT`` to make room for it.
    """
    summary = result_dict["summary"]
    portfolio = result_dict["portfolio"]

    return_lines: List[Tuple[pd.Series, LineInfo]] = [(portfolio.unit_net_value - 1, LINE_STRATEGY)]
    if "benchmark_portfolio" in result_dict:
        benchmark_portfolio = result_dict["benchmark_portfolio"]
        plot_template = plot_template_cls(portfolio.unit_net_value, benchmark_portfolio.unit_net_value)
        ex_returns = plot_template.geometric_excess_returns
        ex_max_dd_ddd = "MaxDD {}\nMaxDDD {}".format(
            _compact_index_range(_max_dd(ex_returns + 1, portfolio.index)),
            _compact_index_range(_max_ddd(ex_returns + 1, portfolio.index)),
        )
        indicators = plot_template.INDICATORS + plot_template.EXCESS_INDICATORS

        # Show the benchmark information in the legend.
        _b_str = summary["benchmark_symbol"] if SUPPORT_CHINESE else summary["benchmark"]
        _INFO = LineInfo(
            LINE_BENCHMARK.label + "({})".format(_b_str), LINE_BENCHMARK.color,
            LINE_BENCHMARK.alpha, LINE_BENCHMARK.linewidth
        )

        return_lines.extend([
            (benchmark_portfolio.unit_net_value - 1, _INFO),
            (ex_returns, LINE_EXCESS),
        ])
        if weekly_indicators:
            return_lines.append((weekly_returns(benchmark_portfolio), LINE_WEEKLY_BENCHMARK))
    else:
        ex_max_dd_ddd = "nan"
        plot_template = plot_template_cls(portfolio.unit_net_value, None)
        # Copy the list: it may be extended below, and extending the template's
        # own INDICATORS in place would leak weekly indicators into later calls.
        indicators = list(plot_template.INDICATORS)
    if weekly_indicators:
        return_lines.append((weekly_returns(portfolio), LINE_WEEKLY))
        indicators.extend(plot_template.WEEKLY_INDICATORS)
    max_dd = _max_dd(portfolio.unit_net_value.values, portfolio.index)
    max_ddd = summary["max_drawdown_duration"]
    spots_on_returns: List[Tuple[Sequence[int], SpotInfo]] = [
        ([max_dd.start, max_dd.end], MAX_DD),
        ([max_ddd.start, max_ddd.end], MAX_DDD)
    ]
    if open_close_points and not result_dict["trades"].empty:
        trades: pd.DataFrame = result_dict["trades"]
        spots_on_returns.append((trading_dates_index(trades, POSITION_EFFECT.CLOSE, portfolio.index), CLOSE_POINT))
        spots_on_returns.append((trading_dates_index(trades, POSITION_EFFECT.OPEN, portfolio.index), OPEN_POINT))

    # The two drawdown texts are layered over the summary so they can be looked
    # up by indicator key like any other value; summary entries take precedence.
    sub_plots = [IndicatorArea(indicators, ChainMap(summary, {
        "max_dd_ddd": "MaxDD {}\nMaxDDD {}".format(
            _compact_index_range(max_dd), _compact_index_range(max_ddd)
        ),
        "excess_max_dd_ddd": ex_max_dd_ddd,
    }), plot_template, strategy_name), ReturnPlot(
        portfolio.unit_net_value - 1, return_lines, spots_on_returns
    )]
    if "plots" in result_dict:
        sub_plots.append(UserPlot(result_dict["plots"]))

    if strategy_name:
        # Make room for the title, which is drawn inside the indicator area.
        for p in sub_plots:
            if isinstance(p, IndicatorArea):
                p.height += PLOT_TITLE_HEIGHT

    _plot(summary["strategy_file"], sub_plots, strategy_name)

    system_log.debug(f"Matplotlib backend: {pyplot.get_backend()}")

    if save:
        file_path = save
        if os.path.isdir(save):
            file_path = os.path.join(save, "{}.png".format(summary["strategy_name"]))
        pyplot.savefig(file_path, bbox_inches='tight')

    if show:
        pyplot.show()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc