• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 10312443515

09 Aug 2024 02:07AM UTC coverage: 94.694% (+0.04%) from 94.65%
10312443515

push

github

web-flow
Merge pull request #1034 from murrayrm/ctrlplot_updates-27Jun2024

Control plot refactoring for consistent functionality

9137 of 9649 relevant lines covered (94.69%)

8.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.97
control/ctrlplot.py
1
# ctrlplot.py - utility functions for plotting
2
# Richard M. Murray, 14 Jun 2024
3
#
4
# Collection of functions that are used by various plotting functions.
5

6
# Code pattern for control system plotting functions:
7
#
8
# def name_plot(sysdata, *fmt, plot=None, **kwargs):
9
#     # Process keywords and set defaults
10
#     ax = kwargs.pop('ax', None)
11
#     color = kwargs.pop('color', None)
12
#     label = kwargs.pop('label', None)
13
#     rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
14
#
15
#     # Make sure all keyword arguments were processed (if not checked later)
16
#     if kwargs:
17
#         raise TypeError("unrecognized keywords: ", str(kwargs))
18
#
19
#     # Process the data (including generating responses for systems)
20
#     sysdata = list(sysdata)
21
#     if any([isinstance(sys, InputOutputSystem) for sys in sysdata]):
22
#         data = name_response(sysdata)
23
#     nrows = max([data.noutputs for data in sysdata])
24
#     ncols = max([data.ninputs for data in sysdata])
25
#
26
#     # Legacy processing of plot keyword
27
#     if plot is False:
28
#         return data.x, data.y
29
#
30
#     # Figure out the shape of the plot and find/create axes
31
#     fig, ax_array = _process_ax_keyword(ax, (nrows, ncols), rcParams)
32
#     legend_loc, legend_map, show_legend = _process_legend_keywords(
33
#         kwargs, (nrows, ncols), 'center right')
34
#
35
#     # Customize axes (curvilinear grids, shared axes, etc)
36
#
37
#     # Plot the data
38
#     lines = np.full(ax_array.shape, [])
39
#     line_labels = _process_line_labels(label, ntraces, nrows, ncols)
40
#     color_offset, color_cycle = _get_color_offset(ax)
41
#     for i, j in itertools.product(range(nrows), range(ncols)):
42
#         ax = ax_array[i, j]
43
#         for k in range(ntraces):
44
#             if color is None:
45
#                 color = _get_color(
46
#                     color, fmt=fmt, offset=k, color_cycle=color_cycle)
47
#             label = line_labels[k, i, j]
48
#             lines[i, j] += ax.plot(data.x, data.y, color=color, label=label)
49
#
50
#     # Customize and label the axes
51
#     for i, j in itertools.product(range(nrows), range(ncols)):
52
#         ax_array[i, j].set_xlabel("x label")
53
#         ax_array[i, j].set_ylabel("y label")
54
#
55
#     # Create legends
56
#     if show_legend is not False:
57
#         legend_array = np.full(ax_array.shape, None, dtype=object)
58
#         for i, j in itertools.product(range(nrows), range(ncols)):
59
#             if legend_map[i, j] is not None:
60
#                 legend_lines, labels = _get_line_labels(ax_array[i, j])
61
#                 labels = _make_legend_labels(labels)
62
#                 if len(labels) > 1:
63
#                     legend_array[i, j] = ax_array[i, j].legend(
64
#                         legend_lines, labels, loc=legend_map[i, j])
65
#     else:
66
#         legend_array = None
67
#
68
#     # Update the plot title (only if ax was not given)
69
#     sysnames = [response.sysname for response in data]
70
#     if ax is None and title is None:
71
#         title = "Name plot for " + ", ".join(sysnames)
72
#         _update_plot_title(title, fig, rcParams=rcParams)
73
#     elif ax is None:
74
#         _update_plot_title(title, fig, rcParams=rcParams, use_existing=False)
75
#
76
#     # Legacy processing of plot keyword
77
#     if plot is True:
78
#         return data
79
#
80
#     return ControlPlot(lines, ax_array, fig, legend=legend_array)
81

82
import itertools
9✔
83
import warnings
9✔
84
from os.path import commonprefix
9✔
85

86
import matplotlib as mpl
9✔
87
import matplotlib.pyplot as plt
9✔
88
import numpy as np
9✔
89

90
from . import config
9✔
91

92
# Public interface of this module
__all__ = [
    'ControlPlot',
    'suptitle',
    'get_plot_axes',
    'pole_zero_subplots',
    'rcParams',
    'reset_rcParams',
]
95

96
#
97
# Style parameters
98
#
99

100
rcParams_default = {
9✔
101
    'axes.labelsize': 'small',
102
    'axes.titlesize': 'small',
103
    'figure.titlesize': 'medium',
104
    'legend.fontsize': 'x-small',
105
    'xtick.labelsize': 'small',
106
    'ytick.labelsize': 'small',
107
}
108
_ctrlplot_rcParams = rcParams_default.copy()    # provide access inside module
9✔
109
rcParams = _ctrlplot_rcParams                   # provide access outside module
9✔
110

111
_ctrlplot_defaults = {'ctrlplot.rcParams': _ctrlplot_rcParams}
9✔
112

113

114
#
115
# Control figure
116
#
117

118
class ControlPlot(object):
    """A class for returning control figures.

    This class is used as the return type for control plotting functions.
    It contains the information required to access portions of the plot
    that the user might want to adjust, as well as providing methods to
    modify some of the properties of the plot.

    A control figure consists of a :class:`matplotlib.figure.Figure` with
    an array of :class:`matplotlib.axes.Axes`.  Each axes in the figure has
    a number of lines that represent the data for the plot.  There may also
    be a legend present in one or more of the axes.

    Attributes
    ----------
    lines : array of list of :class:`matplotlib:Line2D`
        Array of Line2D objects for each line in the plot.  Generally, the
        shape of the array matches the subplots shape and the value of the
        array is a list of Line2D objects in that subplot.  Some plotting
        functions will return variants of this structure, as described in
        the individual documentation for the functions.
    axes : 2D array of :class:`matplotlib:Axes`
        Array of Axes objects for each subplot in the plot.
    figure : :class:`matplotlib:Figure`
        Figure on which the Axes are drawn.
    legend : :class:`matplotlib:.legend.Legend` (instance or ndarray)
        Legend object(s) for the plot.  If more than one legend is
        included, this will be an array with each entry being either None
        (for no legend) or a legend object.

    """
    def __init__(self, lines, axes=None, figure=None, legend=None):
        self.lines = lines
        if axes is None:
            # Recover the axes from the first line in each array entry
            _get_axes = np.vectorize(lambda lines: lines[0].axes)
            axes = _get_axes(lines)
        self.axes = np.atleast_2d(axes)
        if figure is None:
            figure = self.axes[0, 0].figure
        self.figure = figure
        self.legend = legend

    # Implement methods and properties to allow legacy interface (np.array)
    def __iter__(self):
        # __iter__ must return an iterator; returning the ndarray itself
        # makes iter() (and hence for-loops and unpacking) raise TypeError
        return iter(self.lines)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        # stacklevel=2 attributes the warning to the caller's indexing
        warnings.warn(
            "return of Line2D objects from plot function is deprecated in "
            "favor of ControlPlot; use out.lines to access Line2D objects",
            category=FutureWarning, stacklevel=2)
        return self.lines[item]

    def __setitem__(self, item, val):
        self.lines[item] = val

    shape = property(lambda self: self.lines.shape, None)

    def reshape(self, *args):
        return self.lines.reshape(*args)

    def set_plot_title(self, title, frame='axes', **kwargs):
        """Set the title for a control plot.

        This is a wrapper for the matplotlib `suptitle` function, but by
        setting ``frame`` to 'axes' (default) then the title is centered on
        the midpoint of the axes in the figure, rather than the center of
        the figure.  This usually looks better (particularly with
        multi-panel plots), though it takes longer to render.

        Parameters
        ----------
        title : str
            Title text.
        frame : str, optional
            Coordinate frame to use for centering: 'axes' (default) or 'figure'.
        **kwargs : :func:`matplotlib.pyplot.suptitle` keywords, optional
            Additional keywords (passed to matplotlib).

        """
        _update_plot_title(
            title, fig=self.figure, frame=frame, use_existing=False,
            **kwargs)
198

199
#
200
# User functions
201
#
202
# The functions below can be used by users to modify control plots or get
203
# information about them.
204
#
205

206
def suptitle(
        title, fig=None, frame='axes', **kwargs):
    """Add a centered title to a figure.

    This function is deprecated.  Use :func:`ControlPlot.set_plot_title`.

    Parameters
    ----------
    title : str
        Title text.
    fig : Figure, optional
        Matplotlib figure.  Defaults to current figure.
    frame : str, optional
        Coordinate frame to use for centering: 'axes' (default) or 'figure'.
    **kwargs : :func:`matplotlib.pyplot.suptitle` keywords, optional
        Additional keywords (passed to matplotlib).

    """
    # stacklevel=2 makes the deprecation warning point at the caller
    warnings.warn(
        "suptitle is deprecated; use cplt.set_plot_title", FutureWarning,
        stacklevel=2)
    _update_plot_title(
        title, fig=fig, frame=frame, use_existing=False, **kwargs)
217

218

219
# Get the axes associated with an array of lines (deprecated; use cplt.axes)
220
def get_plot_axes(line_array):
    """Get an array of axes from an array of lines.

    This function is deprecated; use the `axes` attribute of the
    :class:`ControlPlot` object returned by the plotting function instead.

    This function can be used to return the set of axes corresponding
    to the line array that is returned by `time_response_plot`.  This
    is useful for generating an axes array that can be passed to
    subsequent plotting calls.

    Parameters
    ----------
    line_array : array of list of Line2D, or ControlPlot
        A 2D array with elements corresponding to a list of lines appearing
        in an axes, matching the return type of a time response data plot.
        If a :class:`ControlPlot` is given, its `lines` attribute is used.

    Returns
    -------
    axes_array : array of Axes
        An array (with the same shape as the line array) whose elements are
        the Axes associated with the lines in `line_array`.

    Notes
    -----
    Only the first element of each array entry is used to determine the axes.

    """
    # stacklevel=2 makes the deprecation warning point at the caller
    warnings.warn(
        "get_plot_axes is deprecated; use cplt.axes", FutureWarning,
        stacklevel=2)
    if isinstance(line_array, ControlPlot):
        line_array = line_array.lines

    # Only the first line in each entry is used to identify its axes
    _get_axes = np.vectorize(lambda lines: lines[0].axes)
    return _get_axes(line_array)
251

252

253
def pole_zero_subplots(
9✔
254
        nrows, ncols, grid=None, dt=None, fig=None, scaling=None,
255
        rcParams=None):
256
    """Create axes for pole/zero plot.
257

258
    Parameters
259
    ----------
260
    nrows, ncols : int
261
        Number of rows and columns.
262
    grid : True, False, or 'empty', optional
263
        Grid style to use.  Can also be a list, in which case each subplot
264
        will have a different style (columns then rows).
265
    dt : timebase, option
266
        Timebase for each subplot (or a list of timebases).
267
    scaling : 'auto', 'equal', or None
268
        Scaling to apply to the subplots.
269
    fig : :class:`matplotlib.figure.Figure`
270
        Figure to use for creating subplots.
271

272
    Returns
273
    -------
274
    ax_array : array
275
        2D array of axes
276

277
    """
278
    from .grid import nogrid, sgrid, zgrid
9✔
279
    from .iosys import isctime
9✔
280

281
    if fig is None:
9✔
282
        fig = plt.gcf()
9✔
283
    rcParams = config._get_param('ctrlplot', 'rcParams', rcParams)
9✔
284

285
    if not isinstance(grid, list):
9✔
286
        grid = [grid] * nrows * ncols
9✔
287
    if not isinstance(dt, list):
9✔
288
        dt = [dt] * nrows * ncols
9✔
289

290
    ax_array = np.full((nrows, ncols), None)
9✔
291
    index = 0
9✔
292
    with plt.rc_context(rcParams):
9✔
293
        for row, col in itertools.product(range(nrows), range(ncols)):
9✔
294
            match grid[index], isctime(dt=dt[index]):
9✔
295
                case 'empty', _:        # empty grid
9✔
296
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
297

298
                case True, True:        # continuous time grid
9✔
299
                    ax_array[row, col], _ = sgrid(
9✔
300
                        (nrows, ncols, index+1), scaling=scaling)
301

302
                case True, False:       # discrete time grid
9✔
303
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
304
                    zgrid(ax=ax_array[row, col], scaling=scaling)
9✔
305

306
                case False | None, _:   # no grid (just stability boundaries)
9✔
307
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
308
                    nogrid(
9✔
309
                        ax=ax_array[row, col], dt=dt[index], scaling=scaling)
310
            index += 1
9✔
311
    return ax_array
9✔
312

313

314
def reset_rcParams():
    """Reset rcParams to default values for control plots.

    The settings dictionary is modified in place, so the public `rcParams`
    object (which is the same dictionary) reflects the change.  Entries
    added by the user are removed, so that afterwards the settings are
    exactly equal to `rcParams_default`.

    """
    _ctrlplot_rcParams.clear()
    _ctrlplot_rcParams.update(rcParams_default)
317

318

319
#
320
# Utility functions
321
#
322
# These functions are used by plotting routines to provide a consistent way
323
# of processing and displaying information.
324
#
325

326
def _process_ax_keyword(
9✔
327
        axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False,
328
        create_axes=True):
329
    """Process ax keyword to plotting commands.
330

331
    This function processes the `ax` keyword to plotting commands.  If no
332
    ax keyword is passed, the current figure is checked to see if it has
333
    the correct shape.  If the shape matches the desired shape, then the
334
    current figure and axes are returned.  Otherwise a new figure is
335
    created with axes of the desired shape.
336

337
    If `create_axes` is False and a new/empty figure is returned, then axs
338
    is an array of the proper shape but None for each element.  This allows
339
    the calling function to do the actual axis creation (needed for
340
    curvilinear grids that use the AxisArtist module).
341

342
    Legacy behavior: some of the older plotting commands use a axes label
343
    to identify the proper axes for plotting.  This behavior is supported
344
    through the use of the label keyword, but will only work if shape ==
345
    (1, 1) and squeeze == True.
346

347
    """
348
    if axs is None:
9✔
349
        fig = plt.gcf()         # get current figure (or create new one)
9✔
350
        axs = fig.get_axes()
9✔
351

352
        # Check to see if axes are the right shape; if not, create new figure
353
        # Note: can't actually check the shape, just the total number of axes
354
        if len(axs) != np.prod(shape):
9✔
355
            with plt.rc_context(rcParams):
9✔
356
                if len(axs) != 0 and create_axes:
9✔
357
                    # Create a new figure
358
                    fig, axs = plt.subplots(*shape, squeeze=False)
9✔
359
                elif create_axes:
9✔
360
                    # Create new axes on (empty) figure
361
                    axs = fig.subplots(*shape, squeeze=False)
9✔
362
                else:
363
                    # Create an empty array and let user create axes
364
                    axs = np.full(shape, None)
9✔
365
            if create_axes:     # if not creating axes, leave these to caller
9✔
366
                fig.set_layout_engine('tight')
9✔
367
                fig.align_labels()
9✔
368

369
        else:
370
            # Use the existing axes, properly reshaped
371
            axs = np.asarray(axs).reshape(*shape)
9✔
372

373
            if clear_text:
9✔
374
                # Clear out any old text from the current figure
375
                for text in fig.texts:
9✔
376
                    text.set_visible(False)     # turn off the text
9✔
377
                    del text                    # get rid of it completely
9✔
378
    else:
379
        axs = np.atleast_1d(axs)
9✔
380
        try:
9✔
381
            axs = axs.reshape(shape)
9✔
382
        except ValueError:
9✔
383
            raise ValueError(
9✔
384
                "specified axes are not the right shape; "
385
                f"got {axs.shape} but expecting {shape}")
386
        fig = axs[0, 0].figure
9✔
387

388
    # Process the squeeze keyword
389
    if squeeze and shape == (1, 1):
9✔
390
        axs = axs[0, 0]         # Just return the single axes object
9✔
391
    elif squeeze:
9✔
392
        axs = axs.squeeze()
×
393

394
    return fig, axs
9✔
395

396

397
# Turn label keyword into array indexed by trace, output, input
398
# TODO: move to ctrlutil.py and update parameter names to reflect general use
399
def _process_line_labels(label, ntraces=1, ninputs=0, noutputs=0):
9✔
400
    if label is None:
9✔
401
        return None
9✔
402

403
    if isinstance(label, str):
9✔
404
        label = [label] * ntraces          # single label for all traces
9✔
405

406
    # Convert to an ndarray, if not done already
407
    try:
9✔
408
        line_labels = np.asarray(label)
9✔
409
    except ValueError:
×
410
        raise ValueError("label must be a string or array_like")
×
411

412
    # Turn the data into a 3D array of appropriate shape
413
    # TODO: allow more sophisticated broadcasting (and error checking)
414
    try:
9✔
415
        if ninputs > 0 and noutputs > 0:
9✔
416
            if line_labels.ndim == 1 and line_labels.size == ntraces:
9✔
417
                line_labels = line_labels.reshape(ntraces, 1, 1)
9✔
418
                line_labels = np.broadcast_to(
9✔
419
                    line_labels, (ntraces, ninputs, noutputs))
420
            else:
421
                line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
9✔
422
    except ValueError:
×
423
        if line_labels.shape[0] != ntraces:
×
424
            raise ValueError("number of labels must match number of traces")
×
425
        else:
426
            raise ValueError("labels must be given for each input/output pair")
×
427

428
    return line_labels
9✔
429

430

431
# Get labels for all lines in an axes
432
def _get_line_labels(ax, use_color=True):
9✔
433
    labels_colors, lines = [], []
9✔
434
    last_color, counter = None, 0       # label unknown systems
9✔
435
    for i, line in enumerate(ax.get_lines()):
9✔
436
        label = line.get_label()
9✔
437
        color = line.get_color()
9✔
438
        if use_color and label.startswith("Unknown"):
9✔
439
            label = f"Unknown-{counter}"
×
440
            if last_color != color:
×
441
                counter += 1
×
442
            last_color = color
×
443
        elif label[0] == '_':
9✔
444
            continue
9✔
445

446
        if (label, color) not in labels_colors:
9✔
447
            lines.append(line)
9✔
448
            labels_colors.append((label, color))
9✔
449

450
    return lines, [label for label, color in labels_colors]
9✔
451

452

453
def _process_legend_keywords(
9✔
454
        kwargs, shape=None, default_loc='center right'):
455
    legend_loc = kwargs.pop('legend_loc', None)
9✔
456
    if shape is None and 'legend_map' in kwargs:
9✔
457
        raise TypeError("unexpected keyword argument 'legend_map'")
9✔
458
    else:
459
        legend_map = kwargs.pop('legend_map', None)
9✔
460
    show_legend = kwargs.pop('show_legend', None)
9✔
461

462
    # If legend_loc or legend_map were given, always show the legend
463
    if legend_loc is False or legend_map is False:
9✔
464
        if show_legend is True:
9✔
465
            warnings.warn(
×
466
                "show_legend ignored; legend_loc or legend_map was given")
467
        show_legend = False
9✔
468
        legend_loc = legend_map = None
9✔
469
    elif legend_loc is not None or legend_map is not None:
9✔
470
        if show_legend is False:
9✔
471
            warnings.warn(
×
472
                "show_legend ignored; legend_loc or legend_map was given")
473
        show_legend = True
9✔
474

475
    if legend_loc is None:
9✔
476
        legend_loc = default_loc
9✔
477
    elif not isinstance(legend_loc, (int, str)):
9✔
478
        raise ValueError("legend_loc must be string or int")
×
479

480
    # Make sure the legend map is the right size
481
    if legend_map is not None:
9✔
482
        legend_map = np.atleast_2d(legend_map)
9✔
483
        if legend_map.shape != shape:
9✔
484
            raise ValueError("legend_map shape just match axes shape")
×
485

486
    return legend_loc, legend_map, show_legend
9✔
487

488

489
# Utility function to make legend labels
490
def _make_legend_labels(labels, ignore_common=False):
9✔
491
    if len(labels) == 1:
9✔
492
        return labels
9✔
493

494
    # Look for a common prefix (up to a space)
495
    common_prefix = commonprefix(labels)
9✔
496
    last_space = common_prefix.rfind(', ')
9✔
497
    if last_space < 0 or ignore_common:
9✔
498
        common_prefix = ''
9✔
499
    elif last_space > 0:
9✔
500
        common_prefix = common_prefix[:last_space + 2]
9✔
501
    prefix_len = len(common_prefix)
9✔
502

503
    # Look for a common suffix (up to a space)
504
    common_suffix = commonprefix(
9✔
505
        [label[::-1] for label in labels])[::-1]
506
    suffix_len = len(common_suffix)
9✔
507
    # Only chop things off after a comma or space
508
    while suffix_len > 0 and common_suffix[-suffix_len] != ',':
9✔
509
        suffix_len -= 1
9✔
510

511
    # Strip the labels of common information
512
    if suffix_len > 0 and not ignore_common:
9✔
513
        labels = [label[prefix_len:-suffix_len] for label in labels]
9✔
514
    else:
515
        labels = [label[prefix_len:] for label in labels]
9✔
516

517
    return labels
9✔
518

519

520
def _update_plot_title(
9✔
521
        title, fig=None, frame='axes', use_existing=True, **kwargs):
522
    if title is False or title is None:
9✔
523
        return
9✔
524
    if fig is None:
9✔
525
        fig = plt.gcf()
9✔
526
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
527

528
    if use_existing:
9✔
529
        # Get the current title, if it exists
530
        old_title = None if fig._suptitle is None else fig._suptitle._text
9✔
531

532
        if old_title is not None:
9✔
533
            # Find the common part of the titles
534
            common_prefix = commonprefix([old_title, title])
9✔
535

536
            # Back up to the last space
537
            last_space = common_prefix.rfind(' ')
9✔
538
            if last_space > 0:
9✔
539
                common_prefix = common_prefix[:last_space]
9✔
540
            common_len = len(common_prefix)
9✔
541

542
            # Add the new part of the title (usually the system name)
543
            if old_title[common_len:] != title[common_len:]:
9✔
544
                separator = ',' if len(common_prefix) > 0 else ';'
9✔
545
                title = old_title + separator + title[common_len:]
9✔
546

547
    if frame == 'figure':
9✔
548
        with plt.rc_context(rcParams):
9✔
549
            fig.suptitle(title, **kwargs)
9✔
550

551
    elif frame == 'axes':
9✔
552
        with plt.rc_context(rcParams):
9✔
553
            fig.suptitle(title, **kwargs)           # Place title in center
9✔
554
            plt.tight_layout()                      # Put everything into place
9✔
555
            xc, _ = _find_axes_center(fig, fig.get_axes())
9✔
556
            fig.suptitle(title, x=xc, **kwargs)     # Redraw title, centered
9✔
557

558
    else:
559
        raise ValueError(f"unknown frame '{frame}'")
9✔
560

561

562
def _find_axes_center(fig, axs):
9✔
563
    """Find the midpoint between axes in display coordinates.
564

565
    This function finds the middle of a plot as defined by a set of axes.
566

567
    """
568
    inv_transform = fig.transFigure.inverted()
9✔
569
    xlim = ylim = [1, 0]
9✔
570
    for ax in axs:
9✔
571
        ll = inv_transform.transform(ax.transAxes.transform((0, 0)))
9✔
572
        ur = inv_transform.transform(ax.transAxes.transform((1, 1)))
9✔
573

574
        xlim = [min(ll[0], xlim[0]), max(ur[0], xlim[1])]
9✔
575
        ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])]
9✔
576

577
    return (np.sum(xlim)/2, np.sum(ylim)/2)
9✔
578

579

580
# Internal function to add arrows to a curve
581
def _add_arrows_to_line2D(
        axes, line, arrow_locs=(0.2, 0.4, 0.6, 0.8),
        arrowstyle='-|>', arrowsize=1, dir=1):
    """Add arrows to a matplotlib.lines.Line2D at selected locations.

    Parameters
    ----------
    axes : Axes
        Axes object as returned by axes command (or gca).
    line : Line2D
        Line object as returned by plot command.
    arrow_locs : sequence of float, optional
        Locations where to insert arrows, as fractions of the total arc
        length of the line.  Lines that are short relative to the axes
        diagonal get a single arrow (or none at all).
    arrowstyle : str or ArrowStyle, optional
        Arrow style, passed to FancyArrowPatch.
    arrowsize : float, optional
        Size of the arrow.  NOTE(review): not used in the body below.
    dir : int, optional
        Arrow direction: 1 (default) points toward increasing point index,
        -1 points toward decreasing index.

    Returns
    -------
    arrows : list of FancyArrowPatch
        Arrows added to the axes (empty if the line has fewer than two
        points or is too short).

    Raises
    ------
    ValueError
        If `line` is not a Line2D object.
    NotImplementedError
        If the line has an array of colors or linewidths.

    Based on https://stackoverflow.com/questions/26911898/

    """
    # Get the coordinates of the line, in plot coordinates
    if not isinstance(line, mpl.lines.Line2D):
        raise ValueError("expected a matplotlib.lines.Line2D object")
    x, y = line.get_xdata(), line.get_ydata()

    # A line with fewer than two points has no length or direction, so there
    # is nowhere to put an arrow (and the arc length below would be empty)
    if len(x) < 2:
        return []

    # Determine the arrow properties
    arrow_kw = {"arrowstyle": arrowstyle}

    color = line.get_color()
    if isinstance(color, np.ndarray):
        raise NotImplementedError("multicolor lines not supported")
    arrow_kw['color'] = color

    linewidth = line.get_linewidth()
    if isinstance(linewidth, np.ndarray):
        raise NotImplementedError("multiwidth lines not supported")
    arrow_kw['linewidth'] = linewidth

    # Figure out the size of the axes (length of diagonal)
    xlim, ylim = axes.get_xlim(), axes.get_ylim()
    ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
    diag = np.linalg.norm(ul - lr)

    # Compute the arc length along the curve (s[k] = length up to x[k+1])
    s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))

    # Truncate the number of arrows if the curve is short
    # TODO: figure out a smarter way to do this
    frac = min(s[-1] / diag, 1)
    if len(arrow_locs) and frac < 0.05:
        arrow_locs = []         # too short; no arrows at all
    elif len(arrow_locs) and frac < 0.2:
        arrow_locs = [0.5]      # single arrow in the middle

    # Plot the arrows (and return list of patches)
    arrows = []
    for loc in arrow_locs:
        n = np.searchsorted(s, s[-1] * loc)

        if dir == 1 and n == 0:
            # Move the arrow forward by one if it is at start of a segment
            n = 1

        # Place the head of the arrow at the desired location
        arrow_head = [x[n], y[n]]
        arrow_tail = [x[n - dir], y[n - dir]]

        p = mpl.patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=axes.transData, lw=0,
            **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)
    return arrows
658

659

660
def _get_color_offset(ax, color_cycle=None):
9✔
661
    """Get color offset based on current lines.
662

663
    This function determines that the current offset is for the next color
664
    to use based on current colors in a plot.
665

666
    Parameters
667
    ----------
668
    ax : matplotlib.axes.Axes
669
        Axes containing already plotted lines.
670
    color_cycle : list of matplotlib color specs, optional
671
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
672
        color cycle.
673

674
    Returns
675
    -------
676
    color_offset : matplotlib color spec
677
        Starting color for next line to be drawn.
678
    color_cycle : list of matplotlib color specs
679
        Color cycle used to determine colors.
680

681
    """
682
    if color_cycle is None:
9✔
683
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
9✔
684

685
    color_offset = 0
9✔
686
    if len(ax.lines) > 0:
9✔
687
        last_color = ax.lines[-1].get_color()
9✔
688
        if last_color in color_cycle:
9✔
689
            color_offset = color_cycle.index(last_color) + 1
9✔
690

691
    return color_offset % len(color_cycle), color_cycle
9✔
692

693

694
def _get_color(
9✔
695
        colorspec, offset=None, fmt=None, ax=None, lines=None,
696
        color_cycle=None):
697
    """Get color to use for plotting line.
698

699
    This function returns the color to be used for the line to be drawn (or
700
    None if the detault color cycle for the axes should be used).
701

702
    Parameters
703
    ----------
704
    colorspec : matplotlib color specification
705
        User-specified color (or None).
706
    offset : int, optional
707
        Offset into the color cycle (for multi-trace plots).
708
    fmt : str, optional
709
        Format string passed to plotting command.
710
    ax : matplotlib.axes.Axes, optional
711
        Axes containing already plotted lines.
712
    lines : list of matplotlib.lines.Line2D, optional
713
        List of plotted lines.  If not given, use ax.get_lines().
714
    color_cycle : list of matplotlib color specs, optional
715
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
716
        color cycle.
717

718
    Returns
719
    -------
720
    color : matplotlib color spec
721
        Color to use for this line (or None for matplotlib default).
722

723
    """
724
    # See if the color was explicitly specified by the user
725
    if isinstance(colorspec, dict):
9✔
726
        if 'color' in colorspec:
9✔
727
            return colorspec.pop('color')
9✔
728
    elif fmt is not None and \
9✔
729
         [isinstance(arg, str) and
730
          any([c in arg for c in "bgrcmykw#"]) for arg in fmt]:
731
        return None             # *fmt will set the color
9✔
732
    elif colorspec != None:
9✔
733
        return colorspec
9✔
734

735
    # Figure out what color cycle to use, if not given by caller
736
    if color_cycle == None:
9✔
737
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
9✔
738

739
    # Find the lines that we should pay attention to
740
    if lines is None and ax is not None:
9✔
741
        lines = ax.lines
9✔
742

743
    # If we were passed a set of lines, try to increment color from previous
744
    if offset is not None:
9✔
745
        return color_cycle[offset]
9✔
746
    elif lines is not None:
9✔
747
        color_offset = 0
9✔
748
        if len(ax.lines) > 0:
9✔
749
            last_color = ax.lines[-1].get_color()
9✔
750
            if last_color in color_cycle:
9✔
751
                color_offset = color_cycle.index(last_color) + 1
9✔
752
        color_offset = color_offset % len(color_cycle)
9✔
753
        return color_cycle[color_offset]
9✔
754
    else:
755
        return None
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc