• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 10029876295

21 Jul 2024 04:45PM UTC coverage: 94.629%. Remained the same
10029876295

push

github

web-flow
Merge pull request #1033 from murrayrm/ctrlplot_refactor-27Jun2024

Move ctrlplot code prior to upcoming PR

8915 of 9421 relevant lines covered (94.63%)

8.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.08
control/ctrlplot.py
1
# ctrlplot.py - utility functions for plotting
2
# Richard M. Murray, 14 Jun 2024
3
#
4
# Collection of functions that are used by various plotting functions.
5

6
from os.path import commonprefix
9✔
7

8
import matplotlib as mpl
9✔
9
import matplotlib.pyplot as plt
9✔
10
import numpy as np
9✔
11

12
from . import config
9✔
13

14
__all__ = ['suptitle', 'get_plot_axes']

#
# Style parameters
#

# Plot styling used by ctrlplot: a snapshot of the matplotlib rcParams in
# effect at import time, with the text elements shrunk so that multi-panel
# control plots remain legible.
_ctrlplot_rcParams = mpl.rcParams.copy()
_ctrlplot_rcParams.update({
    'axes.labelsize': 'small',
    'axes.titlesize': 'small',
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small',
    'figure.titlesize': 'medium',
    'legend.fontsize': 'x-small',
})
29

30

31
#
32
# User functions
33
#
34
# The functions below can be used by users to modify ctrl plots or get
35
# information about them.
36
#
37

38

39
def suptitle(
        title, fig=None, frame='axes', **kwargs):
    """Add a centered title to a figure.

    This is a wrapper for the matplotlib `suptitle` function, but by
    setting ``frame`` to 'axes' (default) then the title is centered on the
    midpoint of the axes in the figure, rather than the center of the
    figure.  This usually looks better (particularly with multi-panel
    plots), though it takes longer to render.

    Parameters
    ----------
    title : str
        Title text.
    fig : Figure, optional
        Matplotlib figure.  Defaults to current figure.
    frame : str, optional
        Coordinate frame to use for centering: 'axes' (default) or 'figure'.
    rcParams : dict, optional
        Matplotlib parameters applied while drawing the title.  Removed
        from `kwargs` and resolved via
        ``config._get_param('ctrlplot', 'rcParams', ...)``.
    **kwargs : :func:`matplotlib.pyplot.suptitle` keywords, optional
        Additional keywords (passed to matplotlib).

    Raises
    ------
    ValueError
        If `frame` is not 'axes' or 'figure'.

    """
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    if fig is None:
        fig = plt.gcf()

    if frame == 'figure':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)

    elif frame == 'axes':
        # TODO: move common plotting params to 'ctrlplot'
        with plt.rc_context(rcParams):
            # Lay out the figure being titled (pyplot.tight_layout would act
            # on the *current* figure, which need not be `fig`) so that the
            # axes positions used to find the center are up to date
            fig.tight_layout()
            xc, _ = _find_axes_center(fig, fig.get_axes())

            fig.suptitle(title, x=xc, **kwargs)
            fig.tight_layout()          # Update the layout for the title

    else:
        raise ValueError(f"unknown frame '{frame}'")
81

82

83
# Create vectorized function to find axes from lines
def get_plot_axes(line_array):
    """Get a list of axes from an array of lines.

    This function can be used to return the set of axes corresponding
    to the line array that is returned by `time_response_plot`.  This
    is useful for generating an axes array that can be passed to
    subsequent plotting calls.

    Parameters
    ----------
    line_array : array of list of Line2D
        A 2D array with elements corresponding to a list of lines appearing
        in an axes, matching the return type of a time response data plot.

    Returns
    -------
    axes_array : array of list of Axes
        A 2D array with elements corresponding to the Axes associated with
        the lines in `line_array`.

    Notes
    -----
    Only the first element of each array entry is used to determine the axes.
    An empty `line_array` yields an empty object array of the same shape.

    """
    # Declaring the output type up front keeps the Axes objects as-is and
    # lets np.vectorize handle size-0 input (otherwise numpy must call the
    # function on the first element to infer the type, which fails when
    # there are no elements)
    _get_axes = np.vectorize(lambda lines: lines[0].axes, otypes=[object])
    return _get_axes(line_array)
111

112
#
113
# Utility functions
114
#
115
# These functions are used by plotting routines to provide a consistent way
116
# of processing and displaying information.
117
#
118

119

120
def _process_ax_keyword(
        axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False):
    """Utility function to process ax keyword to plotting commands.

    This function processes the `ax` keyword to plotting commands.  If no
    ax keyword is passed, the current figure is checked to see if it has
    the correct number of axes.  If it does, then the current figure and
    axes (reshaped to `shape`) are returned.  Otherwise a new figure is
    created with axes of the desired shape (or, if the current figure has
    no axes at all, new axes are created on it).

    Parameters
    ----------
    axs : array_like of Axes, or None
        Axes passed in by the user.  If None, the current figure is used.
    shape : tuple of int, optional
        Desired shape of the axes array.  Default is (1, 1).
    rcParams : dict, optional
        Matplotlib parameters used while creating new axes.
    squeeze : bool, optional
        If True, return a single Axes object when `shape` is (1, 1) and a
        squeezed array otherwise.  Default is False.
    clear_text : bool, optional
        If True and the existing axes of the current figure are reused,
        hide any text objects attached to that figure.

    Returns
    -------
    fig : Figure
        Figure containing the axes.
    axs : ndarray of Axes, or Axes
        Axes array of shape `shape` (squeezed if `squeeze` is True).

    Raises
    ------
    ValueError
        If `axs` is given but cannot be reshaped to `shape`.

    """
    if axs is None:
        fig = plt.gcf()         # get current figure (or create new one)
        axs = fig.get_axes()

        # Check to see if axes are the right shape; if not, create new figure
        # Note: can't actually check the shape, just the total number of axes
        if len(axs) != np.prod(shape):
            with plt.rc_context(rcParams):
                if len(axs) != 0:
                    # Create a new figure
                    fig, axs = plt.subplots(*shape, squeeze=False)
                else:
                    # Create new axes on (empty) figure
                    axs = fig.subplots(*shape, squeeze=False)
            fig.set_layout_engine('tight')
            fig.align_labels()
        else:
            # Use the existing axes, properly reshaped
            axs = np.asarray(axs).reshape(*shape)

            if clear_text:
                # Hide any old text in the current figure.  The text objects
                # stay attached to the figure; only their visibility changes.
                for text in fig.texts:
                    text.set_visible(False)
    else:
        # Convert before reshaping so that the error message below can
        # report the shape of what was passed (lists have no .shape)
        axs = np.asarray(axs)
        try:
            axs = axs.reshape(shape)
        except ValueError:
            raise ValueError(
                "specified axes are not the right shape; "
                f"got {axs.shape} but expecting {shape}") from None
        fig = axs[0, 0].figure

    # Process the squeeze keyword
    if squeeze and shape == (1, 1):
        axs = axs[0, 0]         # Just return the single axes object
    elif squeeze:
        axs = axs.squeeze()

    return fig, axs
177

178

179
# Turn label keyword into array of shape (ntraces, ninputs, noutputs)
# TODO: move to ctrlutil.py and update parameter names to reflect general use
def _process_line_labels(label, ntraces, ninputs=0, noutputs=0):
    """Convert a `label` keyword into an array of line labels.

    Parameters
    ----------
    label : str, array_like, or None
        Label(s) for the plotted lines.  A single string is used for every
        trace.
    ntraces : int
        Number of traces.
    ninputs, noutputs : int, optional
        If both are positive, the labels are arranged into an array of
        shape (ntraces, ninputs, noutputs).  A 1D array with one label per
        trace is broadcast across all input/output pairs.  Otherwise the
        labels are returned as an ndarray without reshaping.

    Returns
    -------
    line_labels : ndarray or None
        None if `label` is None.

    Raises
    ------
    ValueError
        If `label` cannot be converted to an array, or its size does not
        match the requested number of traces/inputs/outputs.

    """
    if label is None:
        return None

    if isinstance(label, str):
        label = [label] * ntraces          # single label for all traces

    # Convert to an ndarray, if not done already
    try:
        line_labels = np.asarray(label)
    except ValueError:
        raise ValueError("label must be a string or array_like")

    # Turn the data into a 3D array of appropriate shape
    # TODO: allow more sophisticated broadcasting (and error checking)
    try:
        if ninputs > 0 and noutputs > 0:
            if line_labels.ndim == 1 and line_labels.size == ntraces:
                line_labels = line_labels.reshape(ntraces, 1, 1)
                line_labels = np.broadcast_to(
                    line_labels, (ntraces, ninputs, noutputs))
            else:
                line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
    except ValueError:
        # A 0-d array (single non-string label) has no leading dimension;
        # report it as a trace-count mismatch instead of letting
        # shape[0] raise IndexError
        if line_labels.ndim == 0 or line_labels.shape[0] != ntraces:
            raise ValueError("number of labels must match number of traces")
        else:
            raise ValueError("labels must be given for each input/output pair")

    return line_labels
211

212

213
# Get labels for all lines in an axes
def _get_line_labels(ax, use_color=True):
    """Return the lines of an axes to show in a legend, with their labels.

    Lines whose label is empty or starts with an underscore are skipped,
    since such lines are conventionally excluded from legends.  When
    `use_color` is True, lines whose label starts with "Unknown" are
    relabeled "Unknown-0", "Unknown-1", ..., with the counter advancing
    each time the line color changes, so unnamed systems drawn in
    different colors get distinct entries.  Only the first line for each
    distinct label is kept.

    Parameters
    ----------
    ax : Axes
        Axes whose lines (``ax.get_lines()``) are inspected.
    use_color : bool, optional
        Whether to number "Unknown" lines by color.  Default is True.

    Returns
    -------
    lines : list of Line2D
        One line per distinct label, in drawing order.
    labels : list of str
        The corresponding labels.

    """
    labels, lines = [], []
    last_color, counter = None, 0       # label unknown systems
    for line in ax.get_lines():
        label = line.get_label()
        if use_color and label.startswith("Unknown"):
            # Advance the counter *before* building the label, so the first
            # line of a new color gets a new name instead of reusing the
            # previous one (and then being dropped as a duplicate)
            color = line.get_color()
            if last_color is None:
                last_color = color
            elif last_color != color:
                counter += 1
                last_color = color
            label = f"Unknown-{counter}"
        elif not label or label.startswith('_'):
            continue

        if label not in labels:
            lines.append(line)
            labels.append(label)

    return lines, labels
234

235

236
# Utility function to make legend labels
def _make_legend_labels(labels, ignore_common=False):
    """Strip information shared by all labels to make compact legend text.

    Labels are treated as comma-separated fields (e.g. "y[0], sys1").  A
    prefix common to all labels is removed up to and including the last
    ", " it contains, and a common suffix is removed starting at a comma,
    so that only the distinguishing fields remain.

    Parameters
    ----------
    labels : list of str
        Labels to process.
    ignore_common : bool, optional
        If True, the labels are returned without stripping anything.

    Returns
    -------
    list of str
        New list of processed labels.

    """
    # With a single label everything is "common"; stripping would leave an
    # empty or truncated string, so return it unchanged
    if len(labels) == 1:
        return list(labels)

    # Look for a common prefix, ending at the last ", " separator
    common_prefix = commonprefix(labels)
    last_sep = common_prefix.rfind(', ')
    if last_sep < 0 or ignore_common:
        common_prefix = ''
    else:
        # Include the separator so the stripped labels don't start with ", "
        common_prefix = common_prefix[:last_sep + 2]
    prefix_len = len(common_prefix)

    # Look for a common suffix (compare the reversed strings)
    common_suffix = commonprefix(
        [label[::-1] for label in labels])[::-1]
    suffix_len = len(common_suffix)
    # Only chop things off starting at a comma
    while suffix_len > 0 and common_suffix[-suffix_len] != ',':
        suffix_len -= 1

    # Strip the labels of common information
    if suffix_len > 0 and not ignore_common:
        labels = [label[prefix_len:-suffix_len] for label in labels]
    else:
        labels = [label[prefix_len:] for label in labels]

    return labels
263

264

265
def _update_suptitle(fig, title, rcParams=None, frame='axes'):
    """Set, or extend, the suptitle of a figure.

    If `fig` already has a suptitle, the titles are compared.  Their
    common prefix is backed up to its last space, and the differing
    remainder of `title` is appended to the existing title.  The separator
    is ',' when there is a common prefix and ';' otherwise.  For example,
    "Step response for sys1" followed by "Step response for sys2" gives
    "Step response for sys1, sys2".  Nothing happens if `fig` is None or
    `title` is not a string.

    Parameters
    ----------
    fig : Figure or None
        Figure to update.
    title : str
        New title text.
    rcParams : dict, optional
        Passed through to `suptitle`.
    frame : str, optional
        Passed through to `suptitle` ('axes' or 'figure').

    """
    if fig is None or not isinstance(title, str):
        return

    # Get the current title, if it exists.  There is no version-independent
    # public accessor for the suptitle artist, but its text is read via the
    # public Text.get_text().
    old_title = None if fig._suptitle is None else fig._suptitle.get_text()

    if old_title is not None:
        # Find the common part of the titles
        common_prefix = commonprefix([old_title, title])

        # Back up to the last space
        last_space = common_prefix.rfind(' ')
        if last_space > 0:
            common_prefix = common_prefix[:last_space]
        common_len = len(common_prefix)

        # Add the new part of the title (usually the system name)
        if old_title[common_len:] != title[common_len:]:
            separator = ',' if len(common_prefix) > 0 else ';'
            title = old_title + separator + title[common_len:]

    # Add the title
    suptitle(title, fig=fig, rcParams=rcParams, frame=frame)
287

288

289
def _find_axes_center(fig, axs):
    """Find the midpoint of the region spanned by a set of axes.

    The bounding box enclosing all of the axes in `axs` is computed in
    figure coordinates (fractions of the figure, obtained by mapping
    display coordinates back through ``fig.transFigure``), and its center
    is returned.

    Parameters
    ----------
    fig : Figure
        Figure containing the axes.
    axs : iterable of Axes
        Axes whose combined extent is used.

    Returns
    -------
    xc, yc : float
        Center of the bounding box in figure coordinates; (0.5, 0.5) if
        `axs` is empty.

    """
    # Maps display coordinates back to figure coordinates
    inv_transform = fig.transFigure.inverted()

    # Start from an empty ("inverted") box and grow it to cover each axes
    xmin, xmax = 1, 0
    ymin, ymax = 1, 0
    for ax in axs:
        # Lower-left and upper-right corners of the axes, in figure coords
        ll = inv_transform.transform(ax.transAxes.transform((0, 0)))
        ur = inv_transform.transform(ax.transAxes.transform((1, 1)))

        xmin, xmax = min(ll[0], xmin), max(ur[0], xmax)
        ymin, ymax = min(ll[1], ymin), max(ur[1], ymax)

    return ((xmin + xmax) / 2, (ymin + ymax) / 2)
305

306

307
# Internal function to add arrows to a curve
def _add_arrows_to_line2D(
        axes, line, arrow_locs=(0.2, 0.4, 0.6, 0.8),
        arrowstyle='-|>', arrowsize=1, dir=1):
    """Add arrows to a matplotlib.lines.Line2D at selected locations.

    Parameters
    ----------
    axes : Axes
        Axes object as returned by axes command (or gca).
    line : Line2D
        Line object as returned by plot command.
    arrow_locs : sequence of float, optional
        Locations at which to insert arrows, as fractions (0 to 1) of the
        total arc length of the line.  If the arc length is under 20% of
        the axes diagonal, a single arrow at 0.5 is used instead.  If it is
        under 5%, no arrows are drawn.
    arrowstyle : str or ArrowStyle, optional
        Style of the arrow (passed to FancyArrowPatch).
    arrowsize : float, optional
        Size of the arrow.  NOTE(review): this value is currently not used
        by the implementation.
    dir : int, optional
        Arrow direction, expected to be 1 (default, pointing toward
        increasing data index) or -1 (pointing the other way).

    Returns
    -------
    arrows : list of FancyArrowPatch
        Arrows added to `axes`.  Empty if the line has fewer than two
        points or is too short.

    Raises
    ------
    ValueError
        If `line` is not a Line2D object.
    NotImplementedError
        If the line uses per-point colors or line widths.

    Notes
    -----
    Based on https://stackoverflow.com/questions/26911898/

    """
    # Get the coordinates of the line, in plot coordinates
    if not isinstance(line, mpl.lines.Line2D):
        raise ValueError("expected a matplotlib.lines.Line2D object")
    x, y = line.get_xdata(), line.get_ydata()

    # Determine the arrow properties
    arrow_kw = {"arrowstyle": arrowstyle}

    color = line.get_color()
    if isinstance(color, np.ndarray):
        raise NotImplementedError("multicolor lines not supported")
    arrow_kw['color'] = color

    linewidth = line.get_linewidth()
    if isinstance(linewidth, np.ndarray):
        raise NotImplementedError("multiwidth lines not supported")
    arrow_kw['linewidth'] = linewidth

    # At least one segment is needed to measure arc length and orient an
    # arrow (otherwise s[-1] below would raise IndexError)
    if len(x) < 2:
        return []

    # Figure out the size of the axes (length of diagonal)
    xlim, ylim = axes.get_xlim(), axes.get_ylim()
    ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
    diag = np.linalg.norm(ul - lr)

    # Compute the arc length along the curve; s[i] is the length of the
    # curve from the first point through point i + 1
    s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))

    # Truncate the number of arrows if the curve is short
    # TODO: figure out a smarter way to do this
    frac = min(s[-1] / diag, 1)
    if len(arrow_locs) and frac < 0.05:
        arrow_locs = []         # too short; no arrows at all
    elif len(arrow_locs) and frac < 0.2:
        arrow_locs = [0.5]      # single arrow in the middle

    # Plot the arrows (and return list of patches)
    arrows = []
    for loc in arrow_locs:
        # Index into s where the requested fraction of arc length is reached
        n = np.searchsorted(s, s[-1] * loc)

        if dir == 1 and n == 0:
            # Move the arrow forward by one if it is at start of a segment
            n = 1

        # Place the head of the arrow at the desired location
        arrow_head = [x[n], y[n]]
        arrow_tail = [x[n - dir], y[n - dir]]

        # NOTE(review): lw=0 is passed alongside arrow_kw['linewidth'] and
        # appears to take precedence (arrow outline not stroked); confirm
        # before changing either value
        p = mpl.patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=axes.transData, lw=0,
            **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)
    return arrows
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc