• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 16081041622

04 Jul 2025 09:18PM UTC coverage: 94.733% (-0.01%) from 94.745%
16081041622

push

github

web-flow
Merge pull request #1155 from murrayrm/fix_nyquist_rescaling-24Mar2025

Update Nyquist rescaling + other improvements

9946 of 10499 relevant lines covered (94.73%)

8.28 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.97
control/ctrlplot.py
1
# ctrlplot.py - utility functions for plotting
2
# RMM, 14 Jun 2024
3
#
4

5
"""Utility functions for plotting.
6

7
This module contains a collection of functions that are used by
8
various plotting functions.
9

10
"""
11

12
# Code pattern for control system plotting functions:
13
#
14
# def name_plot(sysdata, *fmt, plot=None, **kwargs):
15
#     # Process keywords and set defaults
16
#     ax = kwargs.pop('ax', None)
17
#     color = kwargs.pop('color', None)
18
#     label = kwargs.pop('label', None)
19
#     rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
20
#
21
#     # Make sure all keyword arguments were processed (if not checked later)
22
#     if kwargs:
23
#         raise TypeError("unrecognized keywords: " + str(kwargs))
24
#
25
#     # Process the data (including generating responses for systems)
26
#     sysdata = list(sysdata)
27
#     if any([isinstance(sys, InputOutputSystem) for sys in sysdata]):
28
#         data = name_response(sysdata)
29
#     nrows = max([data.noutputs for data in sysdata])
30
#     ncols = max([data.ninputs for data in sysdata])
31
#
32
#     # Legacy processing of plot keyword
33
#     if plot is False:
34
#         return data.x, data.y
35
#
36
#     # Figure out the shape of the plot and find/create axes
37
#     fig, ax_array = _process_ax_keyword(ax, (nrows, ncols), rcParams)
38
#     legend_loc, legend_map, show_legend = _process_legend_keywords(
39
#         kwargs, (nrows, ncols), 'center right')
40
#
41
#     # Customize axes (curvilinear grids, shared axes, etc)
42
#
43
#     # Plot the data
44
#     lines = np.empty(ax_array.shape, dtype=object)
45
#     for i in range(ax_array.shape[0]):
46
#         for j in range(ax_array.shape[1]):
47
#             lines[i, j] = []
48
#     line_labels = _process_line_labels(label, ntraces, nrows, ncols)
49
#     color_offset, color_cycle = _get_color_offset(ax)
50
#     for i, j in itertools.product(range(nrows), range(ncols)):
51
#         ax = ax_array[i, j]
52
#         for k in range(ntraces):
53
#             if color is None:
54
#                 color = _get_color(
55
#                     color, fmt=fmt, offset=k, color_cycle=color_cycle)
56
#             label = line_labels[k, i, j]
57
#             lines[i, j] += ax.plot(data.x, data.y, color=color, label=label)
58
#
59
#     # Customize and label the axes
60
#     for i, j in itertools.product(range(nrows), range(ncols)):
61
#         ax_array[i, j].set_xlabel("x label")
62
#         ax_array[i, j].set_ylabel("y label")
63
#
64
#     # Create legends
#     if show_legend != False:
#         legend_array = np.full(ax_array.shape, None, dtype=object)
#         for i, j in itertools.product(range(nrows), range(ncols)):
#             if legend_map[i, j] is not None:
#                 legend_lines, labels = _get_line_labels(ax_array[i, j])
#                 labels = _make_legend_labels(labels)
#                 if len(labels) > 1:
#                     legend_array[i, j] = ax_array[i, j].legend(
#                         legend_lines, labels, loc=legend_map[i, j])
#     else:
#         legend_array = None
76
#
77
#     # Update the plot title (only if ax was not given)
78
#     sysnames = [response.sysname for response in data]
79
#     if ax is None and title is None:
80
#         title = "Name plot for " + ", ".join(sysnames)
81
#         _update_plot_title(title, fig, rcParams=rcParams)
82
#     elif ax is None:
83
#         _update_plot_title(title, fig, rcParams=rcParams, use_existing=False)
84
#
85
#     # Legacy processing of plot keyword
86
#     if plot is True:
87
#         return data
88
#
89
#     return ControlPlot(lines, ax_array, fig, legend=legend_array)
90

91
import itertools
9✔
92
import warnings
9✔
93
from os.path import commonprefix
9✔
94

95
import matplotlib as mpl
9✔
96
import matplotlib.pyplot as plt
9✔
97
import numpy as np
9✔
98

99
from . import config
9✔
100

101
# Public names exported by this module
__all__ = [
    'ControlPlot', 'suptitle', 'get_plot_axes', 'pole_zero_subplots',
    'rcParams', 'reset_rcParams']

#
# Style parameters
#

# Default matplotlib rcParams used (via plt.rc_context) when drawing
# control plots.  `reset_rcParams` copies these values back into the
# active parameter dictionary.
rcParams_default = {
    'axes.labelsize': 'small',
    'axes.titlesize': 'small',
    'figure.titlesize': 'medium',
    'legend.fontsize': 'x-small',
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small',
}
# NOTE: `_ctrlplot_rcParams`, `rcParams`, and the 'ctrlplot.rcParams'
# entry of `_ctrlplot_defaults` all refer to the *same* dict object.
# Changes to the active parameters must mutate that dict in place (as
# `reset_rcParams` does); rebinding any one name would break the aliasing.
_ctrlplot_rcParams = rcParams_default.copy()    # provide access inside module
rcParams = _ctrlplot_rcParams                   # provide access outside module

# Configuration defaults contributed by this module (presumably merged
# into `config.defaults` elsewhere in the package -- confirm in config.py)
_ctrlplot_defaults = {'ctrlplot.rcParams': _ctrlplot_rcParams}
121

122

123
#
124
# Control figure
125
#
126

127
class ControlPlot():
    """Return class for control plotting functions.

    This class is used as the return type for control plotting functions.
    It contains the information required to access portions of the plot
    that the user might want to adjust, as well as providing methods to
    modify some of the properties of the plot.

    A control figure consists of a `matplotlib.figure.Figure` with
    an array of `matplotlib.axes.Axes`.  Each axes in the figure has
    a number of lines that represent the data for the plot.  There may also
    be a legend present in one or more of the axes.

    Parameters
    ----------
    lines : array of list of `matplotlib.lines.Line2D`
        Array of Line2D objects for each line in the plot.  Generally, the
        shape of the array matches the subplots shape and the value of the
        array is a list of Line2D objects in that subplot.  Some plotting
        functions will return variants of this structure, as described in
        the individual documentation for the functions.
    axes : 2D array of `matplotlib.axes.Axes`, optional
        Array of Axes objects for each subplot in the plot.  If not given,
        the axes containing the first line of each entry of `lines` are
        used.
    figure : `matplotlib.figure.Figure`, optional
        Figure on which the Axes are drawn.  If not given, the figure of
        the first (top left) axes is used.
    legend : `matplotlib.legend.Legend` (instance or ndarray)
        Legend object(s) for the plot.  If more than one legend is
        included, this will be an array with each entry being either None
        (for no legend) or a legend object.

    """
    def __init__(self, lines, axes=None, figure=None, legend=None):
        self.lines = lines
        if axes is None:
            # Use the axes holding the first line of each array entry
            _get_axes = np.vectorize(lambda lines: lines[0].axes)
            axes = _get_axes(lines)
        self.axes = np.atleast_2d(axes)
        if figure is None:
            figure = self.axes[0, 0].figure
        self.figure = figure
        self.legend = legend

    #
    # Methods and properties that let a ControlPlot be used like the
    # array of lines returned by earlier versions (legacy interface)
    #
    def __iter__(self):
        # __iter__ must return an iterator; returning the array itself
        # makes iter() raise TypeError
        return iter(self.lines)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        warnings.warn(
            "return of Line2D objects from plot function is deprecated in "
            "favor of ControlPlot; use out.lines to access Line2D objects",
            category=FutureWarning, stacklevel=2)
        return self.lines[item]

    def __setitem__(self, item, val):
        self.lines[item] = val

    shape = property(lambda self: self.lines.shape, None)

    def reshape(self, *args):
        """Reshape lines array (legacy)."""
        return self.lines.reshape(*args)

    def set_plot_title(self, title, frame='axes', **kwargs):
        """Set the title for a control plot.

        This is a wrapper for the matplotlib `suptitle` function, but by
        setting `frame` to 'axes' (default) then the title is centered on
        the midpoint of the axes in the figure, rather than the center of
        the figure.  This usually looks better (particularly with
        multi-panel plots), though it takes longer to render.

        Parameters
        ----------
        title : str
            Title text.
        frame : str, optional
            Coordinate frame for centering: 'axes' (default) or 'figure'.
        **kwargs : `matplotlib.pyplot.suptitle` keywords, optional
            Additional keywords (passed to matplotlib).  An `rcParams`
            keyword, if given, sets the matplotlib rcParams used while
            drawing the title.

        """
        # Title is always replaced on the figure this plot belongs to
        _update_plot_title(
            title, fig=self.figure, frame=frame, use_existing=False,
            **kwargs)
208

209
#
210
# User functions
211
#
212
# The functions below can be used by users to modify control plots or get
213
# information about them.
214
#
215

216
def suptitle(
        title, fig=None, frame='axes', **kwargs):
    """Add a centered title to a figure.

    .. deprecated:: 0.10.1
        Use `ControlPlot.set_plot_title`.

    Parameters
    ----------
    title : str
        Title text.
    fig : `matplotlib.figure.Figure`, optional
        Figure to add the title to.  Defaults to the current figure.
    frame : str, optional
        Coordinate frame for centering: 'axes' (default) or 'figure'.
    **kwargs : `matplotlib.pyplot.suptitle` keywords, optional
        Additional keywords (passed to matplotlib).

    """
    deprecation_msg = "suptitle() is deprecated; use cplt.set_plot_title()"
    warnings.warn(deprecation_msg, FutureWarning)

    # Replace (rather than merge with) any title already on the figure
    _update_plot_title(
        title, fig=fig, frame=frame, use_existing=False, **kwargs)
228

229

230
# Deprecated helper: map each entry of a line array to the axes it is drawn in
def get_plot_axes(line_array):
    """Get a list of axes from an array of lines.

    .. deprecated:: 0.10.1
        This function will be removed in a future version of python-control.
        Use `cplt.axes` to obtain axes for an instance of `ControlPlot`.

    This function can be used to return the set of axes corresponding
    to the line array that is returned by `time_response_plot`.  This
    is useful for generating an axes array that can be passed to
    subsequent plotting calls.

    Parameters
    ----------
    line_array : array of list of `matplotlib.lines.Line2D`
        A 2D array with elements corresponding to a list of lines appearing
        in an axes, matching the return type of a time response data plot.
        A `ControlPlot` instance is also accepted (its `lines` are used).

    Returns
    -------
    axes_array : array of list of `matplotlib.axes.Axes`
        A 2D array with elements corresponding to the Axes associated with
        the lines in `line_array`.

    Notes
    -----
    Only the first element of each array entry is used to determine the axes.

    """
    warnings.warn(
        "get_plot_axes() is deprecated; use cplt.axes()", FutureWarning)

    # Unwrap ControlPlot objects to get at the underlying line array
    if isinstance(line_array, ControlPlot):
        line_array = line_array.lines

    first_line_axes = np.vectorize(lambda lines: lines[0].axes)
    return first_line_axes(line_array)
267

268

269
def pole_zero_subplots(
        nrows, ncols, grid=None, dt=None, fig=None, scaling=None,
        rcParams=None):
    """Create axes for pole/zero plot.

    Parameters
    ----------
    nrows, ncols : int
        Number of rows and columns.
    grid : True, False, None, or 'empty', optional
        Grid style to use: True for an s-plane or z-plane grid (depending
        on the timebase), False or None for no grid (stability boundary
        only), and 'empty' for a plain set of axes.  Can also be a list, in
        which case each subplot will have a different style (listed in
        row-major order: across each row, then down).  Any other value
        leaves the corresponding entry of the returned array as None.
    dt : timebase, optional
        Timebase for each subplot (or a list of timebases, in the same
        order as `grid`).
    scaling : 'auto', 'equal', or None
        Scaling to apply to the subplots.
    fig : `matplotlib.figure.Figure`
        Figure to use for creating subplots.  Defaults to the current
        figure.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.

    Returns
    -------
    ax_array : ndarray
        2D array of axes.

    """
    from .grid import nogrid, sgrid, zgrid
    from .iosys import isctime

    if fig is None:
        fig = plt.gcf()
    rcParams = config._get_param('ctrlplot', 'rcParams', rcParams)

    # Broadcast scalar grid/timebase specifications to every subplot
    nplots = nrows * ncols
    grids = grid if isinstance(grid, list) else [grid] * nplots
    timebases = dt if isinstance(dt, list) else [dt] * nplots

    ax_array = np.full((nrows, ncols), None)
    with plt.rc_context(rcParams):
        cells = itertools.product(range(nrows), range(ncols))
        for index, (row, col) in enumerate(cells):
            style = grids[index]
            continuous = isctime(dt=timebases[index])
            position = index + 1        # subplot numbering starts at 1

            if style == 'empty':
                # Empty grid
                ax_array[row, col] = fig.add_subplot(nrows, ncols, position)

            elif style is True and continuous is True:
                # Continuous-time grid (sgrid creates its own axes)
                ax_array[row, col], _ = sgrid(
                    (nrows, ncols, position), scaling=scaling)

            elif style is True and continuous is False:
                # Discrete-time grid
                ax_array[row, col] = fig.add_subplot(nrows, ncols, position)
                zgrid(ax=ax_array[row, col], scaling=scaling)

            elif style is False or style is None:
                # No grid (just stability boundaries)
                ax_array[row, col] = fig.add_subplot(nrows, ncols, position)
                nogrid(
                    ax=ax_array[row, col], dt=timebases[index],
                    scaling=scaling)

    return ax_array
331

332

333
def reset_rcParams():
    """Reset rcParams to default values for control plots.

    The active parameter dictionary is modified in place, so every existing
    reference to it (including the module-level `rcParams` alias) sees the
    default values afterwards.  Any keys that were added to the active
    parameters but are not part of the defaults are removed.

    """
    # clear() first so user-added keys do not survive the reset; mutate in
    # place (never rebind) to keep all aliases of this dict in sync
    _ctrlplot_rcParams.clear()
    _ctrlplot_rcParams.update(rcParams_default)
336

337

338
#
339
# Utility functions
340
#
341
# These functions are used by plotting routines to provide a consistent way
342
# of processing and displaying information.
343
#
344

345
def _process_ax_keyword(
9✔
346
        axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False,
347
        create_axes=True, sharex=False, sharey=False):
348
    """Process ax keyword to plotting commands.
349

350
    This function processes the `ax` keyword to plotting commands.  If no
351
    ax keyword is passed, the current figure is checked to see if it has
352
    the correct shape.  If the shape matches the desired shape, then the
353
    current figure and axes are returned.  Otherwise a new figure is
354
    created with axes of the desired shape.
355

356
    If `create_axes` is False and a new/empty figure is returned, then `axs`
357
    is an array of the proper shape but None for each element.  This allows
358
    the calling function to do the actual axis creation (needed for
359
    curvilinear grids that use the AxisArtist module).
360

361
    Legacy behavior: some of the older plotting commands use an axes label
362
    to identify the proper axes for plotting.  This behavior is supported
363
    through the use of the label keyword, but will only work if shape ==
364
    (1, 1) and squeeze == True.
365

366
    """
367
    if axs is None:
9✔
368
        fig = plt.gcf()         # get current figure (or create new one)
9✔
369
        axs = fig.get_axes()
9✔
370

371
        # Check to see if axes are the right shape; if not, create new figure
372
        # Note: can't actually check the shape, just the total number of axes
373
        if len(axs) != np.prod(shape):
9✔
374
            with plt.rc_context(rcParams):
9✔
375
                if len(axs) != 0 and create_axes:
9✔
376
                    # Create a new figure
377
                    fig, axs = plt.subplots(
9✔
378
                        *shape, sharex=sharex, sharey=sharey, squeeze=False)
379
                elif create_axes:
9✔
380
                    # Create new axes on (empty) figure
381
                    axs = fig.subplots(
9✔
382
                        *shape, sharex=sharex, sharey=sharey, squeeze=False)
383
                else:
384
                    # Create an empty array and let user create axes
385
                    axs = np.full(shape, None)
9✔
386
            if create_axes:     # if not creating axes, leave these to caller
9✔
387
                fig.set_layout_engine('tight')
9✔
388
                fig.align_labels()
9✔
389

390
        else:
391
            # Use the existing axes, properly reshaped
392
            axs = np.asarray(axs).reshape(*shape)
9✔
393

394
            if clear_text:
9✔
395
                # Clear out any old text from the current figure
396
                for text in fig.texts:
9✔
397
                    text.set_visible(False)     # turn off the text
9✔
398
                    del text                    # get rid of it completely
9✔
399
    else:
400
        axs = np.atleast_1d(axs)
9✔
401
        try:
9✔
402
            axs = axs.reshape(shape)
9✔
403
        except ValueError:
9✔
404
            raise ValueError(
9✔
405
                "specified axes are not the right shape; "
406
                f"got {axs.shape} but expecting {shape}")
407
        fig = axs[0, 0].figure
9✔
408

409
    # Process the squeeze keyword
410
    if squeeze and shape == (1, 1):
9✔
411
        axs = axs[0, 0]         # Just return the single axes object
9✔
412
    elif squeeze:
9✔
413
        axs = axs.squeeze()
×
414

415
    return fig, axs
9✔
416

417

418
# Turn label keyword into array indexed by trace, output, input
419
# TODO: move to ctrlutil.py and update parameter names to reflect general use
420
def _process_line_labels(label, ntraces=1, ninputs=0, noutputs=0):
9✔
421
    if label is None:
9✔
422
        return None
9✔
423

424
    if isinstance(label, str):
9✔
425
        label = [label] * ntraces          # single label for all traces
9✔
426

427
    # Convert to an ndarray, if not done already
428
    try:
9✔
429
        line_labels = np.asarray(label)
9✔
430
    except ValueError:
×
431
        raise ValueError("label must be a string or array_like")
×
432

433
    # Turn the data into a 3D array of appropriate shape
434
    # TODO: allow more sophisticated broadcasting (and error checking)
435
    try:
9✔
436
        if ninputs > 0 and noutputs > 0:
9✔
437
            if line_labels.ndim == 1 and line_labels.size == ntraces:
9✔
438
                line_labels = line_labels.reshape(ntraces, 1, 1)
9✔
439
                line_labels = np.broadcast_to(
9✔
440
                    line_labels, (ntraces, ninputs, noutputs))
441
            else:
442
                line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
9✔
443
    except ValueError:
×
444
        if line_labels.shape[0] != ntraces:
×
445
            raise ValueError("number of labels must match number of traces")
×
446
        else:
447
            raise ValueError("labels must be given for each input/output pair")
×
448

449
    return line_labels
9✔
450

451

452
# Get labels for all lines in an axes
453
def _get_line_labels(ax, use_color=True):
9✔
454
    labels_colors, lines = [], []
9✔
455
    last_color, counter = None, 0       # label unknown systems
9✔
456
    for i, line in enumerate(ax.get_lines()):
9✔
457
        label = line.get_label()
9✔
458
        color = line.get_color()
9✔
459
        if use_color and label.startswith("Unknown"):
9✔
460
            label = f"Unknown-{counter}"
×
461
            if last_color != color:
×
462
                counter += 1
×
463
            last_color = color
×
464
        elif label[0] == '_':
9✔
465
            continue
9✔
466

467
        if (label, color) not in labels_colors:
9✔
468
            lines.append(line)
9✔
469
            labels_colors.append((label, color))
9✔
470

471
    return lines, [label for label, color in labels_colors]
9✔
472

473

474
def _process_legend_keywords(
9✔
475
        kwargs, shape=None, default_loc='center right'):
476
    legend_loc = kwargs.pop('legend_loc', None)
9✔
477
    if shape is None and 'legend_map' in kwargs:
9✔
478
        raise TypeError("unexpected keyword argument 'legend_map'")
9✔
479
    else:
480
        legend_map = kwargs.pop('legend_map', None)
9✔
481
    show_legend = kwargs.pop('show_legend', None)
9✔
482

483
    # If legend_loc or legend_map were given, always show the legend
484
    if legend_loc is False or legend_map is False:
9✔
485
        if show_legend is True:
9✔
486
            warnings.warn(
×
487
                "show_legend ignored; legend_loc or legend_map was given")
488
        show_legend = False
9✔
489
        legend_loc = legend_map = None
9✔
490
    elif legend_loc is not None or legend_map is not None:
9✔
491
        if show_legend is False:
9✔
492
            warnings.warn(
×
493
                "show_legend ignored; legend_loc or legend_map was given")
494
        show_legend = True
9✔
495

496
    if legend_loc is None:
9✔
497
        legend_loc = default_loc
9✔
498
    elif not isinstance(legend_loc, (int, str)):
9✔
499
        raise ValueError("legend_loc must be string or int")
×
500

501
    # Make sure the legend map is the right size
502
    if legend_map is not None:
9✔
503
        legend_map = np.atleast_2d(legend_map)
9✔
504
        if legend_map.shape != shape:
9✔
505
            raise ValueError("legend_map shape just match axes shape")
×
506

507
    return legend_loc, legend_map, show_legend
9✔
508

509

510
# Utility function to make legend labels
511
def _make_legend_labels(labels, ignore_common=False):
9✔
512
    if len(labels) == 1:
9✔
513
        return labels
9✔
514

515
    # Look for a common prefix (up to a space)
516
    common_prefix = commonprefix(labels)
9✔
517
    last_space = common_prefix.rfind(', ')
9✔
518
    if last_space < 0 or ignore_common:
9✔
519
        common_prefix = ''
9✔
520
    elif last_space > 0:
9✔
521
        common_prefix = common_prefix[:last_space + 2]
9✔
522
    prefix_len = len(common_prefix)
9✔
523

524
    # Look for a common suffix (up to a space)
525
    common_suffix = commonprefix(
9✔
526
        [label[::-1] for label in labels])[::-1]
527
    suffix_len = len(common_suffix)
9✔
528
    # Only chop things off after a comma or space
529
    while suffix_len > 0 and common_suffix[-suffix_len] != ',':
9✔
530
        suffix_len -= 1
9✔
531

532
    # Strip the labels of common information
533
    if suffix_len > 0 and not ignore_common:
9✔
534
        labels = [label[prefix_len:-suffix_len] for label in labels]
9✔
535
    else:
536
        labels = [label[prefix_len:] for label in labels]
9✔
537

538
    return labels
9✔
539

540

541
def _update_plot_title(
        title, fig=None, frame='axes', use_existing=True, **kwargs):
    """Set or update the title (suptitle) of a figure.

    Parameters
    ----------
    title : str, None, or False
        Title text.  If None or False, the figure is left unchanged.
    fig : `matplotlib.figure.Figure`, optional
        Figure to update.  Defaults to the current figure.
    frame : str, optional
        'axes' (default) to center the title over the axes in the figure,
        or 'figure' to center it over the whole figure.
    use_existing : bool, optional
        If True (default) and the figure already has a title, merge the new
        title into it: the part of the new title that follows the words
        shared with the old title (usually a system name) is appended.
    **kwargs
        An `rcParams` keyword (popped here; default taken from the
        'ctrlplot.rcParams' configuration) used as a matplotlib rc context,
        plus any keywords accepted by `matplotlib.figure.Figure.suptitle`.

    Raises
    ------
    ValueError
        If `frame` is not 'axes' or 'figure'.

    """
    if title is False or title is None:
        return
    if fig is None:
        fig = plt.gcf()
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    if use_existing:
        # Get the current title, if it exists
        old_title = None if fig._suptitle is None else fig._suptitle._text

        if old_title is not None:
            # Find the common part of the titles
            common_prefix = commonprefix([old_title, title])

            # Back up to the last space
            last_space = common_prefix.rfind(' ')
            if last_space > 0:
                common_prefix = common_prefix[:last_space]
            common_len = len(common_prefix)

            # Add the new part of the title (usually the system name)
            if old_title[common_len:] != title[common_len:]:
                separator = ',' if len(common_prefix) > 0 else ';'
                title = old_title + separator + title[common_len:]

    if frame == 'figure':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)

    elif frame == 'axes':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)           # Place title in center
            # Lay out *this* figure: plt.tight_layout() acts on the current
            # figure, which is not necessarily `fig`
            fig.tight_layout()
            xc, _ = _find_axes_center(fig, fig.get_axes())
            fig.suptitle(title, x=xc, **kwargs)     # Redraw title, centered

    else:
        raise ValueError(f"unknown frame '{frame}'")
581

582

583
def _find_axes_center(fig, axs):
9✔
584
    """Find the midpoint between axes in display coordinates.
585

586
    This function finds the middle of a plot as defined by a set of axes.
587

588
    """
589
    inv_transform = fig.transFigure.inverted()
9✔
590
    xlim = ylim = [1, 0]
9✔
591
    for ax in axs:
9✔
592
        ll = inv_transform.transform(ax.transAxes.transform((0, 0)))
9✔
593
        ur = inv_transform.transform(ax.transAxes.transform((1, 1)))
9✔
594

595
        xlim = [min(ll[0], xlim[0]), max(ur[0], xlim[1])]
9✔
596
        ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])]
9✔
597

598
    return (np.sum(xlim)/2, np.sum(ylim)/2)
9✔
599

600

601
# Internal function to add arrows to a curve
def _add_arrows_to_line2D(
        axes, line, arrow_locs=(0.2, 0.4, 0.6, 0.8),
        arrowstyle='-|>', arrowsize=1, dir=1):
    """Add arrows to a matplotlib.lines.Line2D at selected locations.

    Parameters
    ----------
    axes : `matplotlib.axes.Axes`
        Axes containing the line (as returned by the axes command or gca).
    line : `matplotlib.lines.Line2D`
        Line object (as returned by the plot command).
    arrow_locs : sequence of float, optional
        Locations at which to insert arrows, as fractions of the total arc
        length of the line.  Replaced by a single arrow at 0.5 if the line
        is short relative to the axes, or by no arrows if very short.
    arrowstyle : str or `matplotlib.patches.ArrowStyle`, optional
        Style of the arrow (passed to `FancyArrowPatch`).
    arrowsize : float, optional
        Size of the arrow.  NOTE(review): not currently used by the
        implementation; the arrow size comes from `arrowstyle`.
    dir : int, optional
        Arrow direction: 1 (default) points toward increasing data index
        along the line, -1 points the opposite way.

    Returns
    -------
    arrows : list of `matplotlib.patches.FancyArrowPatch`
        Arrows added to the axes (empty if the line has fewer than two
        points or is too short).

    Raises
    ------
    ValueError
        If `line` is not a Line2D object.
    NotImplementedError
        If the line uses multiple colors or widths.

    Notes
    -----
    Based on https://stackoverflow.com/questions/26911898/

    """
    # Get the coordinates of the line, in plot coordinates
    if not isinstance(line, mpl.lines.Line2D):
        raise ValueError("expected a matplotlib.lines.Line2D object")
    x, y = line.get_xdata(), line.get_ydata()

    # Determine the arrow properties
    arrow_kw = {"arrowstyle": arrowstyle}

    color = line.get_color()
    if isinstance(color, np.ndarray):
        raise NotImplementedError("multi-color lines not supported")
    arrow_kw['color'] = color

    linewidth = line.get_linewidth()
    if isinstance(linewidth, np.ndarray):
        raise NotImplementedError("multi-width lines not supported")
    arrow_kw['linewidth'] = linewidth

    # A line needs at least one segment to have an arc length
    if len(x) < 2:
        return []

    # Figure out the size of the axes (length of diagonal)
    xlim, ylim = axes.get_xlim(), axes.get_ylim()
    ll, ur = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
    diag = np.linalg.norm(ll - ur)

    # Compute the arc length along the curve
    s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))

    # Truncate the number of arrows if the curve is short
    # TODO: figure out a smarter way to do this
    frac = min(s[-1] / diag, 1)
    if len(arrow_locs) and frac < 0.05:
        arrow_locs = []         # too short; no arrows at all
    elif len(arrow_locs) and frac < 0.2:
        arrow_locs = [0.5]      # single arrow in the middle

    # Plot the arrows (and return list of patches)
    arrows = []
    for loc in arrow_locs:
        n = np.searchsorted(s, s[-1] * loc)

        if dir == 1 and n == 0:
            # Move the arrow forward by one if it is at start of a segment
            n = 1

        # Place the head of the arrow at the desired location; the tail is
        # the neighboring data point, which sets the arrow direction
        arrow_head = [x[n], y[n]]
        arrow_tail = [x[n - dir], y[n - dir]]

        # NOTE(review): arrow_kw also carries 'linewidth'; confirm which of
        # lw/linewidth matplotlib applies last
        p = mpl.patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=axes.transData, lw=0,
            **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)
    return arrows
681

682

683
def _get_color_offset(ax, color_cycle=None):
9✔
684
    """Get color offset based on current lines.
685

686
    This function determines that the current offset is for the next color
687
    to use based on current colors in a plot.
688

689
    Parameters
690
    ----------
691
    ax : `matplotlib.axes.Axes`
692
        Axes containing already plotted lines.
693
    color_cycle : list of matplotlib color specs, optional
694
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
695
        color cycle.
696

697
    Returns
698
    -------
699
    color_offset : matplotlib color spec
700
        Starting color for next line to be drawn.
701
    color_cycle : list of matplotlib color specs
702
        Color cycle used to determine colors.
703

704
    """
705
    if color_cycle is None:
9✔
706
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
9✔
707

708
    color_offset = 0
9✔
709
    if len(ax.lines) > 0:
9✔
710
        last_color = ax.lines[-1].get_color()
9✔
711
        if last_color in color_cycle:
9✔
712
            color_offset = color_cycle.index(last_color) + 1
9✔
713

714
    return color_offset % len(color_cycle), color_cycle
9✔
715

716

717
def _get_color(
9✔
718
        colorspec, offset=None, fmt=None, ax=None, lines=None,
719
        color_cycle=None):
720
    """Get color to use for plotting line.
721

722
    This function returns the color to be used for the line to be drawn (or
723
    None if the default color cycle for the axes should be used).
724

725
    Parameters
726
    ----------
727
    colorspec : matplotlib color specification
728
        User-specified color (or None).
729
    offset : int, optional
730
        Offset into the color cycle (for multi-trace plots).
731
    fmt : str, optional
732
        Format string passed to plotting command.
733
    ax : `matplotlib.axes.Axes`, optional
734
        Axes containing already plotted lines.
735
    lines : list of matplotlib.lines.Line2D, optional
736
        List of plotted lines.  If not given, use ax.get_lines().
737
    color_cycle : list of matplotlib color specs, optional
738
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
739
        color cycle.
740

741
    Returns
742
    -------
743
    color : matplotlib color spec
744
        Color to use for this line (or None for matplotlib default).
745

746
    """
747
    # See if the color was explicitly specified by the user
748
    if isinstance(colorspec, dict):
9✔
749
        if 'color' in colorspec:
9✔
750
            return colorspec.pop('color')
9✔
751
    elif fmt is not None and \
9✔
752
         [isinstance(arg, str) and
753
          any([c in arg for c in "bgrcmykw#"]) for arg in fmt]:
754
        return None             # *fmt will set the color
9✔
755
    elif colorspec != None:
9✔
756
        return colorspec
9✔
757

758
    # Figure out what color cycle to use, if not given by caller
759
    if color_cycle == None:
9✔
760
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
9✔
761

762
    # Find the lines that we should pay attention to
763
    if lines is None and ax is not None:
9✔
764
        lines = ax.lines
9✔
765

766
    # If we were passed a set of lines, try to increment color from previous
767
    if offset is not None:
9✔
768
        return color_cycle[offset]
9✔
769
    elif lines is not None:
9✔
770
        color_offset = 0
9✔
771
        if len(ax.lines) > 0:
9✔
772
            last_color = ax.lines[-1].get_color()
9✔
773
            if last_color in color_cycle:
9✔
774
                color_offset = color_cycle.index(last_color) + 1
9✔
775
        color_offset = color_offset % len(color_cycle)
9✔
776
        return color_cycle[color_offset]
9✔
777
    else:
778
        return None
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc