• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 10399836602

15 Aug 2024 06:31AM UTC coverage: 94.694% (+0.001%) from 94.693%
10399836602

push

github

web-flow
Merge pull request #1040 from murrayrm/tickmark_labels-08Aug2024

Update shared axes processing in plot_time_response

9138 of 9650 relevant lines covered (94.69%)

8.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.97
control/ctrlplot.py
1
# ctrlplot.py - utility functions for plotting
2
# Richard M. Murray, 14 Jun 2024
3
#
4
# Collection of functions that are used by various plotting functions.
5

6
# Code pattern for control system plotting functions:
7
#
8
# def name_plot(sysdata, *fmt, plot=None, **kwargs):
9
#     # Process keywords and set defaults
10
#     ax = kwargs.pop('ax', None)
11
#     color = kwargs.pop('color', None)
12
#     label = kwargs.pop('label', None)
13
#     rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
14
#
15
#     # Make sure all keyword arguments were processed (if not checked later)
16
#     if kwargs:
17
#         raise TypeError("unrecognized keywords: ", str(kwargs))
18
#
19
#     # Process the data (including generating responses for systems)
20
#     sysdata = list(sysdata)
21
#     if any([isinstance(sys, InputOutputSystem) for sys in sysdata]):
22
#         data = name_response(sysdata)
23
#     nrows = max([data.noutputs for data in sysdata])
24
#     ncols = max([data.ninputs for data in sysdata])
25
#
26
#     # Legacy processing of plot keyword
27
#     if plot is False:
28
#         return data.x, data.y
29
#
30
#     # Figure out the shape of the plot and find/create axes
31
#     fig, ax_array = _process_ax_keyword(ax, (nrows, ncols), rcParams)
32
#     legend_loc, legend_map, show_legend = _process_legend_keywords(
33
#         kwargs, (nrows, ncols), 'center right')
34
#
35
#     # Customize axes (curvilinear grids, shared axes, etc)
36
#
37
#     # Plot the data
38
#     lines = np.full(ax_array.shape, [])
39
#     line_labels = _process_line_labels(label, ntraces, nrows, ncols)
40
#     color_offset, color_cycle = _get_color_offset(ax)
41
#     for i, j in itertools.product(range(nrows), range(ncols)):
42
#         ax = ax_array[i, j]
43
#         for k in range(ntraces):
44
#             if color is None:
45
#                 color = _get_color(
46
#                     color, fmt=fmt, offset=k, color_cycle=color_cycle)
47
#             label = line_labels[k, i, j]
48
#             lines[i, j] += ax.plot(data.x, data.y, color=color, label=label)
49
#
50
#     # Customize and label the axes
51
#     for i, j in itertools.product(range(nrows), range(ncols)):
52
#         ax_array[i, j].set_xlabel("x label")
53
#         ax_array[i, j].set_ylabel("y label")
54
#
55
#     # Create legends
56
#     if show_legend != False:
57
#         legend_array = np.full(ax_array.shape, None, dtype=object)
58
#         for i, j in itertools.product(range(nrows), range(ncols)):
59
#             if legend_map[i, j] is not None:
60
#                 lines, labels = _get_line_labels(ax_array[i, j])
61
#                 labels = _make_legend_labels(labels)
62
#                 if len(labels) > 1:
63
#                     legend_array[i, j] = ax_array[i, j].legend(
64
#                         lines, labels, loc=legend_map[i, j])
65
#     else:
66
#         legend_array = None
67
#
68
#     # Update the plot title (only if ax was not given)
69
#     sysnames = [response.sysname for response in data]
70
#     if ax is None and title is None:
71
#         title = "Name plot for " + ", ".join(sysnames)
72
#         _update_plot_title(title, fig, rcParams=rcParams)
73
#     elif ax is None:
74
#         _update_plot_title(title, fig, rcParams=rcParams, use_existing=False)
75
#
76
#     # Legacy processing of plot keyword
77
#     if plot is True:
78
#         return data
79
#
80
#     return ControlPlot(lines, ax_array, fig, legend=legend_array)
81

82
import itertools
9✔
83
import warnings
9✔
84
from os.path import commonprefix
9✔
85

86
import matplotlib as mpl
9✔
87
import matplotlib.pyplot as plt
9✔
88
import numpy as np
9✔
89

90
from . import config
9✔
91

92
__all__ = [
    'ControlPlot',
    'suptitle',
    'get_plot_axes',
    'pole_zero_subplots',
    'rcParams',
    'reset_rcParams',
]
95

96
#
97
# Style parameters
98
#
99

100
# Default matplotlib rc settings (font sizes) used for control plots
rcParams_default = {
    'axes.labelsize': 'small',
    'axes.titlesize': 'small',
    'figure.titlesize': 'medium',
    'legend.fontsize': 'x-small',
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small',
}

# Working copy of the defaults.  A single dictionary object is shared by the
# module-internal name and the user-facing name, so a change made through
# either one is visible through the other.
_ctrlplot_rcParams = dict(rcParams_default)
rcParams = _ctrlplot_rcParams

# Package configuration entry (key 'ctrlplot.rcParams'); presumably collected
# by the config module -- confirm against config.py
_ctrlplot_defaults = {'ctrlplot.rcParams': _ctrlplot_rcParams}
9✔
112

113

114
#
115
# Control figure
116
#
117

118
class ControlPlot(object):
    """A class for returning control figures.

    This class is used as the return type for control plotting functions.
    It contains the information required to access portions of the plot
    that the user might want to adjust, as well as providing methods to
    modify some of the properties of the plot.

    A control figure consists of a :class:`matplotlib.figure.Figure` with
    an array of :class:`matplotlib.axes.Axes`.  Each axes in the figure has
    a number of lines that represent the data for the plot.  There may also
    be a legend present in one or more of the axes.

    Attributes
    ----------
    lines : array of list of :class:`matplotlib:Line2D`
        Array of Line2D objects for each line in the plot.  Generally, the
        shape of the array matches the subplots shape and the value of the
        array is a list of Line2D objects in that subplot.  Some plotting
        functions will return variants of this structure, as described in
        the individual documentation for the functions.
    axes : 2D array of :class:`matplotlib:Axes`
        Array of Axes objects for each subplot in the plot.  If not given
        to the constructor, the axes are taken from the first line in each
        entry of `lines`.
    figure : :class:`matplotlib:Figure`
        Figure on which the Axes are drawn.  If not given to the
        constructor, the figure of the upper-left axes is used.
    legend : :class:`matplotlib:.legend.Legend` (instance or ndarray)
        Legend object(s) for the plot.  If more than one legend is
        included, this will be an array with each entry being either None
        (for no legend) or a legend object.

    """
    def __init__(self, lines, axes=None, figure=None, legend=None):
        self.lines = lines
        if axes is None:
            # Look up the axes from the first line in each array entry
            _get_axes = np.vectorize(lambda lines: lines[0].axes)
            axes = _get_axes(lines)
        self.axes = np.atleast_2d(axes)
        if figure is None:
            figure = self.axes[0, 0].figure
        self.figure = figure
        self.legend = legend

    #
    # Methods and properties that keep the legacy interface working (plot
    # functions used to return an np.array of lines directly).
    #

    def __iter__(self):
        # __iter__ must return an iterator; returning the array itself
        # makes iter() raise TypeError
        return iter(self.lines)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        warnings.warn(
            "return of Line2D objects from plot function is deprecated in "
            "favor of ControlPlot; use out.lines to access Line2D objects",
            category=FutureWarning)
        return self.lines[item]

    def __setitem__(self, item, val):
        self.lines[item] = val

    @property
    def shape(self):
        """Shape of the `lines` array (read-only)."""
        return self.lines.shape

    def reshape(self, *args):
        """Return the `lines` array reshaped (see :meth:`numpy.ndarray.reshape`)."""
        return self.lines.reshape(*args)

    def set_plot_title(self, title, frame='axes', **kwargs):
        """Set the title for a control plot.

        This is a wrapper for the matplotlib `suptitle` function, but by
        setting ``frame`` to 'axes' (default) then the title is centered on
        the midpoint of the axes in the figure, rather than the center of
        the figure.  This usually looks better (particularly with
        multi-panel plots), though it takes longer to render.

        Any existing title is replaced.

        Parameters
        ----------
        title : str
            Title text.
        frame : str, optional
            Coordinate frame to use for centering: 'axes' (default) or 'figure'.
        **kwargs : :func:`matplotlib.pyplot.suptitle` keywords, optional
            Additional keywords (passed to matplotlib).

        """
        _update_plot_title(
            title, fig=self.figure, frame=frame, use_existing=False,
            **kwargs)
198

199
#
200
# User functions
201
#
202
# The functions below can be used by users to modify control plots or get
203
# information about them.
204
#
205

206
def suptitle(title, fig=None, frame='axes', **kwargs):
    """Add a centered title to a figure.

    .. deprecated:: 0.10.1
        Use :func:`ControlPlot.set_plot_title`.

    Parameters
    ----------
    title : str
        Title text.
    fig : Figure, optional
        Matplotlib figure.  Defaults to the current figure.
    frame : str, optional
        Coordinate frame used for centering: 'axes' (default) or 'figure'.
    **kwargs
        Additional keywords, forwarded to the title helper.

    """
    # Issue the deprecation notice before doing any work
    warnings.warn(
        "suptitle() is deprecated; use cplt.set_plot_title()", FutureWarning)

    # The deprecated interface always replaces (never merges) the title
    options = dict(fig=fig, frame=frame, use_existing=False)
    _update_plot_title(title, **options, **kwargs)
218

219

220
# Get the axes associated with an array of lines (deprecated)
221
def get_plot_axes(line_array):
    """Get a list of axes from an array of lines.

    .. deprecated:: 0.10.1
        This function will be removed in a future version of python-control.
        Use `cplt.axes` to obtain axes for an instance of :class:`ControlPlot`.

    This function can be used to return the set of axes corresponding
    to the line array that is returned by `time_response_plot`.  This
    is useful for generating an axes array that can be passed to
    subsequent plotting calls.

    Parameters
    ----------
    line_array : array of list of Line2D, or ControlPlot
        A 2D array with elements corresponding to a list of lines appearing
        in an axes, matching the return type of a time response data plot.
        If a :class:`ControlPlot` is given, its `lines` attribute is used.

    Returns
    -------
    axes_array : array of list of Axes
        A 2D array with elements corresponding to the Axes associated with
        the lines in `line_array`.

    Notes
    -----
    Only the first element of each array entry is used to determine the axes.

    """
    # `axes` is an attribute of ControlPlot, not a method, so do not tell
    # users to call it
    warnings.warn(
        "get_plot_axes() is deprecated; use cplt.axes", FutureWarning)

    if isinstance(line_array, ControlPlot):
        line_array = line_array.lines

    # The axes of each entry are taken from the first line in that entry
    _get_axes = np.vectorize(lambda lines: lines[0].axes)
    return _get_axes(line_array)
9✔
257

258

259
def pole_zero_subplots(
9✔
260
        nrows, ncols, grid=None, dt=None, fig=None, scaling=None,
261
        rcParams=None):
262
    """Create axes for pole/zero plot.
263

264
    Parameters
265
    ----------
266
    nrows, ncols : int
267
        Number of rows and columns.
268
    grid : True, False, or 'empty', optional
269
        Grid style to use.  Can also be a list, in which case each subplot
270
        will have a different style (columns then rows).
271
    dt : timebase, option
272
        Timebase for each subplot (or a list of timebases).
273
    scaling : 'auto', 'equal', or None
274
        Scaling to apply to the subplots.
275
    fig : :class:`matplotlib.figure.Figure`
276
        Figure to use for creating subplots.
277
    rcParams : dict
278
        Override the default parameters used for generating plots.
279
        Default is set up config.default['ctrlplot.rcParams'].
280

281
    Returns
282
    -------
283
    ax_array : array
284
        2D array of axes
285

286
    """
287
    from .grid import nogrid, sgrid, zgrid
9✔
288
    from .iosys import isctime
9✔
289

290
    if fig is None:
9✔
291
        fig = plt.gcf()
9✔
292
    rcParams = config._get_param('ctrlplot', 'rcParams', rcParams)
9✔
293

294
    if not isinstance(grid, list):
9✔
295
        grid = [grid] * nrows * ncols
9✔
296
    if not isinstance(dt, list):
9✔
297
        dt = [dt] * nrows * ncols
9✔
298

299
    ax_array = np.full((nrows, ncols), None)
9✔
300
    index = 0
9✔
301
    with plt.rc_context(rcParams):
9✔
302
        for row, col in itertools.product(range(nrows), range(ncols)):
9✔
303
            match grid[index], isctime(dt=dt[index]):
9✔
304
                case 'empty', _:        # empty grid
9✔
305
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
306

307
                case True, True:        # continuous time grid
9✔
308
                    ax_array[row, col], _ = sgrid(
9✔
309
                        (nrows, ncols, index+1), scaling=scaling)
310

311
                case True, False:       # discrete time grid
9✔
312
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
313
                    zgrid(ax=ax_array[row, col], scaling=scaling)
9✔
314

315
                case False | None, _:   # no grid (just stability boundaries)
9✔
316
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
317
                    nogrid(
9✔
318
                        ax=ax_array[row, col], dt=dt[index], scaling=scaling)
319
            index += 1
9✔
320
    return ax_array
9✔
321

322

323
def reset_rcParams():
    """Reset rcParams to default values for control plots.

    The shared dictionary is modified in place, so every existing reference
    to it (including ``rcParams``) sees the reset values.  Keys added on top
    of the defaults are removed.
    """
    # update() alone would leave user-added keys behind; clear first
    _ctrlplot_rcParams.clear()
    _ctrlplot_rcParams.update(rcParams_default)
9✔
326

327

328
#
329
# Utility functions
330
#
331
# These functions are used by plotting routines to provide a consistent way
332
# of processing and displaying information.
333
#
334

335
def _process_ax_keyword(
        axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False,
        create_axes=True, sharex=False, sharey=False):
    """Process ax keyword to plotting commands.

    This function processes the `ax` keyword to plotting commands.  If no
    ax keyword is passed, the current figure is checked to see if it has
    the correct shape.  If the shape matches the desired shape, then the
    current figure and axes are returned.  Otherwise a new figure is
    created with axes of the desired shape.

    If `create_axes` is False and a new/empty figure is returned, then axs
    is an array of the proper shape but None for each element.  This allows
    the calling function to do the actual axis creation (needed for
    curvilinear grids that use the AxisArtist module).

    NOTE(review): earlier documentation described legacy support for
    locating axes by label via a `label` keyword; this function has no
    such keyword.

    Parameters
    ----------
    axs : Axes, array_like of Axes, or None
        Axes given by the user, or None to use (or create) the current
        figure.
    shape : tuple of int, optional
        Desired (nrows, ncols) shape of the axes array.
    rcParams : dict, optional
        Matplotlib rc settings used while creating new axes.
    squeeze : bool, optional
        If True, return a single Axes when `shape` is (1, 1), otherwise a
        squeezed array.
    clear_text : bool, optional
        If True and the existing axes of the current figure are reused,
        hide all text objects in that figure.
    create_axes : bool, optional
        If False, new axes are not created; an array of None is returned
        in their place.
    sharex, sharey : bool, optional
        Passed to ``subplots`` when new axes are created.

    Returns
    -------
    fig : :class:`matplotlib.figure.Figure`
        Figure containing the axes.
    axs : ndarray of Axes (or a single Axes if squeezed)
        Axes array of shape `shape` unless squeezed.

    Raises
    ------
    ValueError
        If user-specified axes cannot be reshaped to `shape`.

    """
    if axs is None:
        fig = plt.gcf()         # get current figure (or create new one)
        axs = fig.get_axes()

        # Check to see if axes are the right shape; if not, create new figure
        # Note: can't actually check the shape, just the total number of axes
        if len(axs) != np.prod(shape):
            with plt.rc_context(rcParams):
                if len(axs) != 0 and create_axes:
                    # Create a new figure
                    fig, axs = plt.subplots(
                        *shape, sharex=sharex, sharey=sharey, squeeze=False)
                elif create_axes:
                    # Create new axes on (empty) figure
                    axs = fig.subplots(
                        *shape, sharex=sharex, sharey=sharey, squeeze=False)
                else:
                    # Create an empty array and let user create axes
                    axs = np.full(shape, None)
            if create_axes:     # if not creating axes, leave these to caller
                fig.set_layout_engine('tight')
                fig.align_labels()

        else:
            # Use the existing axes, properly reshaped
            axs = np.asarray(axs).reshape(*shape)

            if clear_text:
                # Hide any old text in the current figure.  The text objects
                # remain attached to the figure; only their visibility changes.
                for text in fig.texts:
                    text.set_visible(False)
    else:
        axs = np.atleast_1d(axs)
        try:
            axs = axs.reshape(shape)
        except ValueError:
            raise ValueError(
                "specified axes are not the right shape; "
                f"got {axs.shape} but expecting {shape}")
        fig = axs[0, 0].figure

    # Process the squeeze keyword
    if squeeze and shape == (1, 1):
        axs = axs[0, 0]         # Just return the single axes object
    elif squeeze:
        axs = axs.squeeze()

    return fig, axs
9✔
406

407

408
# Turn label keyword into array indexed by trace, output, input
409
# TODO: move to ctrlutil.py and update parameter names to reflect general use
410
def _process_line_labels(label, ntraces=1, ninputs=0, noutputs=0):
    """Convert a `label` keyword into an array of line labels.

    Parameters
    ----------
    label : str, array_like, or None
        Label specification.  A single string is repeated for every trace.
    ntraces : int, optional
        Number of traces.
    ninputs, noutputs : int, optional
        If both are positive, the labels are reshaped into a 3D array of
        shape ``(ntraces, ninputs, noutputs)``; a 1D array holding one
        label per trace is broadcast across the last two dimensions.

    Returns
    -------
    line_labels : ndarray or None
        Array of labels, or None if `label` is None.

    Raises
    ------
    ValueError
        If `label` cannot be converted to an array, or its shape is not
        compatible with the requested dimensions.

    """
    if label is None:
        return None

    # A single string is used as the label for every trace
    labels = [label] * ntraces if isinstance(label, str) else label

    try:
        labels = np.asarray(labels)
    except ValueError:
        raise ValueError("label must be a string or array_like")

    # Without a full input/output grid there is nothing to reshape
    if not (ninputs > 0 and noutputs > 0):
        return labels

    # TODO: allow more sophisticated broadcasting (and error checking)
    try:
        if labels.ndim == 1 and labels.size == ntraces:
            per_trace = labels.reshape(ntraces, 1, 1)
            return np.broadcast_to(per_trace, (ntraces, ninputs, noutputs))
        return labels.reshape(ntraces, ninputs, noutputs)
    except ValueError:
        if labels.shape[0] != ntraces:
            raise ValueError("number of labels must match number of traces")
        raise ValueError("labels must be given for each input/output pair")
9✔
440

441

442
# Get labels for all lines in an axes
443
def _get_line_labels(ax, use_color=True):
    """Get the lines and labels to show in a legend for an axes.

    Lines whose label starts with an underscore, or is empty, are skipped
    (the matplotlib convention for unlabeled lines).  Each distinct
    (label, color) pair is reported once, using the first line found.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose lines are examined.
    use_color : bool, optional
        If True (default), lines labeled "Unknown..." are relabeled as
        "Unknown-<n>", with a new number started whenever the line color
        changes, so that runs of same-colored lines share one entry.

    Returns
    -------
    lines : list of matplotlib.lines.Line2D
        One representative line per legend entry.
    labels : list of str
        The corresponding labels.

    """
    labels_colors, lines = [], []
    last_color, counter = None, 0       # label unknown systems
    for line in ax.get_lines():
        label = line.get_label()
        color = line.get_color()
        if use_color and label.startswith("Unknown"):
            # Advance the counter *before* building the label so that a
            # color change starts a new label and same-colored lines share
            # one (building it first gave every line a different number)
            if last_color is not None and last_color != color:
                counter += 1
            label = f"Unknown-{counter}"
            last_color = color
        elif not label or label.startswith('_'):
            # Unlabeled/hidden line ('' would crash label[0] indexing)
            continue

        if (label, color) not in labels_colors:
            lines.append(line)
            labels_colors.append((label, color))

    return lines, [label for label, _ in labels_colors]
9✔
462

463

464
def _process_legend_keywords(
        kwargs, shape=None, default_loc='center right'):
    """Process the legend-related keywords of a plotting function.

    The ``legend_loc``, ``legend_map``, and ``show_legend`` entries are
    removed from `kwargs` (in place) and resolved against each other.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments of the calling function (modified in place).
    shape : tuple of int, optional
        Shape of the axes array.  If None, ``legend_map`` is not accepted.
    default_loc : str or int, optional
        Legend location used when ``legend_loc`` is not given.

    Returns
    -------
    legend_loc : str or int
        Legend location.
    legend_map : 2D ndarray or None
        Per-axes legend locations (None entries mean no legend).
    show_legend : bool or None
        True/False if legend display was forced on/off, else None.

    Raises
    ------
    TypeError
        If ``legend_map`` is given but `shape` is None.
    ValueError
        If ``legend_loc`` is not a string or int, or ``legend_map`` does
        not have shape `shape`.

    """
    legend_loc = kwargs.pop('legend_loc', None)
    if shape is None and 'legend_map' in kwargs:
        raise TypeError("unexpected keyword argument 'legend_map'")
    else:
        legend_map = kwargs.pop('legend_map', None)
    show_legend = kwargs.pop('show_legend', None)

    # If legend_loc or legend_map is False, never show the legend
    if legend_loc is False or legend_map is False:
        if show_legend is True:
            warnings.warn(
                "show_legend ignored; legend_loc or legend_map was given")
        show_legend = False
        legend_loc = legend_map = None
    # If legend_loc or legend_map were given, always show the legend
    elif legend_loc is not None or legend_map is not None:
        if show_legend is False:
            warnings.warn(
                "show_legend ignored; legend_loc or legend_map was given")
        show_legend = True

    if legend_loc is None:
        legend_loc = default_loc
    elif not isinstance(legend_loc, (int, str)):
        raise ValueError("legend_loc must be string or int")

    # Make sure the legend map is the right size
    if legend_map is not None:
        legend_map = np.atleast_2d(legend_map)
        if legend_map.shape != shape:
            raise ValueError("legend_map shape must match axes shape")

    return legend_loc, legend_map, show_legend
9✔
498

499

500
# Utility function to make legend labels
501
def _make_legend_labels(labels, ignore_common=False):
    """Shorten legend labels by removing text common to all of them.

    The common prefix is trimmed back to just after the last ``', '`` it
    contains, and the common suffix is trimmed so that it starts at a
    comma.  As a result, only whole comma-separated fields are removed.

    Parameters
    ----------
    labels : list of str
        Labels to shorten.  A single-element list is returned unchanged.
    ignore_common : bool, optional
        If True, no common prefix or suffix is removed.

    Returns
    -------
    list of str
        The shortened labels.

    """
    if len(labels) == 1:
        return labels

    # Common prefix, cut back to just after the last ', ' separator
    prefix = commonprefix(labels)
    sep_at = prefix.rfind(', ')
    if ignore_common or sep_at < 0:
        prefix = ''
    elif sep_at > 0:
        prefix = prefix[:sep_at + 2]
    start = len(prefix)

    # Common suffix (common prefix of the reversed labels), cut so that
    # it begins at a comma
    suffix = commonprefix([text[::-1] for text in labels])[::-1]
    cut = len(suffix)
    while cut > 0 and suffix[-cut] != ',':
        cut -= 1

    if cut > 0 and not ignore_common:
        return [text[start:-cut] for text in labels]
    return [text[start:] for text in labels]
9✔
529

530

531
def _update_plot_title(
        title, fig=None, frame='axes', use_existing=True, **kwargs):
    """Set (or extend) the suptitle of a figure.

    Parameters
    ----------
    title : str, None, or False
        Title text.  If None or False, nothing is done.
    fig : :class:`matplotlib.figure.Figure`, optional
        Figure to update.  Defaults to the current figure.
    frame : str, optional
        Coordinate frame used for centering: 'axes' (default) centers the
        title over the axes in the figure, 'figure' over the whole figure.
    use_existing : bool, optional
        If True (default) and the figure already has a title, the new title
        is merged into the old one.  The part of the new title after the
        word-aligned common prefix is appended to the old title, separated
        by ',' when there is a common prefix and ';' otherwise.
    **kwargs
        The ``rcParams`` keyword is extracted through ``config._get_param``
        (presumably falling back to the package default) and used as the
        matplotlib rc context.  Remaining keywords are passed to
        :meth:`matplotlib.figure.Figure.suptitle`.

    Raises
    ------
    ValueError
        If `frame` is not 'axes' or 'figure'.

    """
    if title is False or title is None:
        return
    if fig is None:
        fig = plt.gcf()
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    if use_existing:
        # Get the current title, if it exists (matplotlib keeps it in the
        # private attribute `_suptitle`)
        old_title = None if fig._suptitle is None \
            else fig._suptitle.get_text()

        if old_title is not None:
            # Find the common part of the titles
            common_prefix = commonprefix([old_title, title])

            # Back up to the last space
            last_space = common_prefix.rfind(' ')
            if last_space > 0:
                common_prefix = common_prefix[:last_space]
            common_len = len(common_prefix)

            # Add the new part of the title (usually the system name)
            if old_title[common_len:] != title[common_len:]:
                separator = ',' if len(common_prefix) > 0 else ';'
                title = old_title + separator + title[common_len:]

    if frame == 'figure':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)

    elif frame == 'axes':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)           # Place title in center
            # Lay out *this* figure; plt.tight_layout() would act on the
            # current figure, which need not be `fig`
            fig.tight_layout()
            xc, _ = _find_axes_center(fig, fig.get_axes())
            fig.suptitle(title, x=xc, **kwargs)     # Redraw title, centered

    else:
        raise ValueError(f"unknown frame '{frame}'")
9✔
571

572

573
def _find_axes_center(fig, axs):
    """Find the midpoint of a set of axes in figure coordinates.

    The bounding box enclosing every axes in `axs` is computed in figure
    (fractional) coordinates, and its center is returned.  If `axs` is
    empty, (0.5, 0.5) is returned.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure containing the axes.
    axs : iterable of matplotlib.axes.Axes
        Axes to enclose.

    Returns
    -------
    (x, y) : tuple of float
        Center of the bounding box, in figure coordinates.

    """
    to_fig = fig.transFigure.inverted()
    left, bottom, right, top = 1, 1, 0, 0
    for ax in axs:
        # Map the axes corners from axes coordinates to figure coordinates
        x0, y0 = to_fig.transform(ax.transAxes.transform((0, 0)))
        x1, y1 = to_fig.transform(ax.transAxes.transform((1, 1)))
        left, right = min(x0, left), max(x1, right)
        bottom, top = min(y0, bottom), max(y1, top)

    return (np.sum([left, right]) / 2, np.sum([bottom, top]) / 2)
9✔
589

590

591
# Internal function to add arrows to a curve
592
def _add_arrows_to_line2D(
        axes, line, arrow_locs=(0.2, 0.4, 0.6, 0.8),
        arrowstyle='-|>', arrowsize=1, dir=1):
    """
    Add arrows to a matplotlib.lines.Line2D at selected locations.

    Parameters
    ----------
    axes: Axes object as returned by axes command (or gca)
    line: Line2D object as returned by plot command
    arrow_locs: sequence of locations where to insert arrows, given as
        fractions (0 to 1) of the total arc length of the line
    arrowstyle: style of the arrow
    arrowsize: size of the arrow (not currently used)
    dir: 1 to point arrows toward increasing data index, -1 to point them
        the other way

    Returns
    -------
    arrows: list of arrows (FancyArrowPatch objects added to `axes`).  Empty
        if the line has fewer than two points.

    Raises
    ------
    ValueError: if `line` is not a Line2D object
    NotImplementedError: for multicolor or multiwidth lines

    Based on https://stackoverflow.com/questions/26911898/

    """
    # Get the coordinates of the line, in data coordinates
    if not isinstance(line, mpl.lines.Line2D):
        raise ValueError("expected a matplotlib.lines.Line2D object")
    x, y = line.get_xdata(), line.get_ydata()

    # Determine the arrow properties
    arrow_kw = {"arrowstyle": arrowstyle}

    color = line.get_color()
    use_multicolor_lines = isinstance(color, np.ndarray)
    if use_multicolor_lines:
        raise NotImplementedError("multicolor lines not supported")
    else:
        arrow_kw['color'] = color

    linewidth = line.get_linewidth()
    if isinstance(linewidth, np.ndarray):
        raise NotImplementedError("multiwidth lines not supported")
    else:
        arrow_kw['linewidth'] = linewidth

    # Figure out the size of the axes (length of diagonal)
    xlim, ylim = axes.get_xlim(), axes.get_ylim()
    ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
    diag = np.linalg.norm(ul - lr)

    # A line with fewer than two points has no segments (and s[-1] below
    # would raise IndexError), so there is nowhere to put an arrow
    if len(x) < 2:
        return []

    # Compute the arc length along the curve (s[k] = length up to point k+1)
    s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))

    # Truncate the number of arrows if the curve is short
    # TODO: figure out a smarter way to do this
    frac = min(s[-1] / diag, 1)
    if len(arrow_locs) and frac < 0.05:
        arrow_locs = []         # too short; no arrows at all
    elif len(arrow_locs) and frac < 0.2:
        arrow_locs = [0.5]      # single arrow in the middle

    # Plot the arrows (and return list of patches)
    arrows = []
    for loc in arrow_locs:
        n = np.searchsorted(s, s[-1] * loc)

        if dir == 1 and n == 0:
            # Move the arrow forward by one if it is at start of a segment
            n = 1

        # Place the head of the arrow at the desired location
        arrow_head = [x[n], y[n]]
        arrow_tail = [x[n - dir], y[n - dir]]

        p = mpl.patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=axes.transData, lw=0,
            **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)
    return arrows
9✔
669

670

671
def _get_color_offset(ax, color_cycle=None):
    """Get the color-cycle offset implied by the lines already in an axes.

    If the most recently drawn line uses a color from the cycle, the next
    color in the cycle is selected (wrapping around); otherwise the cycle
    starts from the beginning.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes containing already plotted lines.
    color_cycle : list of matplotlib color specs, optional
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
        color cycle.

    Returns
    -------
    color_offset : int
        Index into `color_cycle` of the color for the next line.
    color_cycle : list of matplotlib color specs
        Color cycle used to determine colors.

    """
    if color_cycle is None:
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    next_index = 0
    existing = ax.lines
    if len(existing) > 0:
        previous = existing[-1].get_color()
        if previous in color_cycle:
            next_index = color_cycle.index(previous) + 1

    return next_index % len(color_cycle), color_cycle
9✔
703

704

705
def _get_color(
        colorspec, offset=None, fmt=None, ax=None, lines=None,
        color_cycle=None):
    """Get color to use for plotting line.

    This function returns the color to be used for the line to be drawn (or
    None if the default color cycle for the axes should be used).

    Parameters
    ----------
    colorspec : matplotlib color specification or dict
        User-specified color (or None).  If a dict, a 'color' entry (if
        present) is removed from it and returned.
    offset : int, optional
        Offset into the color cycle (for multi-trace plots).  Offsets past
        the end of the cycle wrap around.
    fmt : sequence of str, optional
        Format arguments passed to plotting command.  If any of them
        contains a color character, None is returned so that the format
        string sets the color.
    ax : matplotlib.axes.Axes, optional
        Axes containing already plotted lines.
    lines : list of matplotlib.lines.Line2D, optional
        List of plotted lines.  If not given, use ax.lines.
    color_cycle : list of matplotlib color specs, optional
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
        color cycle.

    Returns
    -------
    color : matplotlib color spec
        Color to use for this line (or None for matplotlib default).

    """
    # See if the color was explicitly specified by the user
    if isinstance(colorspec, dict):
        if 'color' in colorspec:
            return colorspec.pop('color')
    elif fmt is not None and any(
            isinstance(arg, str) and any(c in arg for c in "bgrcmykw#")
            for arg in fmt):
        # any() is required: a bare list is truthy whenever fmt is non-empty
        return None             # *fmt will set the color
    elif colorspec is not None:
        # `is not None` (not `!=`) so array-valued colors work
        return colorspec

    # Figure out what color cycle to use, if not given by caller
    if color_cycle is None:
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # Find the lines that we should pay attention to
    if lines is None and ax is not None:
        lines = ax.lines

    if offset is not None:
        # Wrap around so that more traces than colors do not fail
        return color_cycle[offset % len(color_cycle)]
    elif lines is not None:
        # Continue the cycle after the color of the most recent line
        color_offset = 0
        if len(lines) > 0:
            last_color = lines[-1].get_color()
            if last_color in color_cycle:
                color_offset = color_cycle.index(last_color) + 1
        return color_cycle[color_offset % len(color_cycle)]
    else:
        return None
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc