• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 10370763703

13 Aug 2024 01:32PM UTC coverage: 94.693% (-0.001%) from 94.694%
10370763703

push

github

web-flow
Merge pull request #1038 from murrayrm/doc-comment_fixes-11May2024

Documentation updates and docstring unit tests

9136 of 9648 relevant lines covered (94.69%)

8.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.97
control/ctrlplot.py
1
# ctrlplot.py - utility functions for plotting
2
# Richard M. Murray, 14 Jun 2024
3
#
4
# Collection of functions that are used by various plotting functions.
5

6
# Code pattern for control system plotting functions:
7
#
8
# def name_plot(sysdata, *fmt, plot=None, **kwargs):
9
#     # Process keywords and set defaults
10
#     ax = kwargs.pop('ax', None)
11
#     color = kwargs.pop('color', None)
12
#     label = kwargs.pop('label', None)
13
#     rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
14
#
15
#     # Make sure all keyword arguments were processed (if not checked later)
16
#     if kwargs:
17
#         raise TypeError("unrecognized keywords: ", str(kwargs))
18
#
19
#     # Process the data (including generating responses for systems)
20
#     sysdata = list(sysdata)
21
#     if any([isinstance(sys, InputOutputSystem) for sys in sysdata]):
22
#         data = name_response(sysdata)
23
#     nrows = max([data.noutputs for data in sysdata])
24
#     ncols = max([data.ninputs for data in sysdata])
25
#
26
#     # Legacy processing of plot keyword
27
#     if plot is False:
28
#         return data.x, data.y
29
#
30
#     # Figure out the shape of the plot and find/create axes
31
#     fig, ax_array = _process_ax_keyword(ax, (nrows, ncols), rcParams)
32
#     legend_loc, legend_map, show_legend = _process_legend_keywords(
33
#         kwargs, (nrows, ncols), 'center right')
34
#
35
#     # Customize axes (curvilinear grids, shared axes, etc)
36
#
37
#     # Plot the data
38
#     lines = np.full(ax_array.shape, [])
39
#     line_labels = _process_line_labels(label, ntraces, nrows, ncols)
40
#     color_offset, color_cycle = _get_color_offset(ax)
41
#     for i, j in itertools.product(range(nrows), range(ncols)):
42
#         ax = ax_array[i, j]
43
#         for k in range(ntraces):
44
#             if color is None:
45
#                 color = _get_color(
46
#                     color, fmt=fmt, offset=k, color_cycle=color_cycle)
47
#             label = line_labels[k, i, j]
48
#             lines[i, j] += ax.plot(data.x, data.y, color=color, label=label)
49
#
50
#     # Customize and label the axes
51
#     for i, j in itertools.product(range(nrows), range(ncols)):
52
#         ax_array[i, j].set_xlabel("x label")
53
#         ax_array[i, j].set_ylabel("y label")
54
#
55
#     # Create legends
56
#     if show_legend != False:
57
#         legend_array = np.full(ax_array.shape, None, dtype=object)
58
#         for i, j in itertools.product(range(nrows), range(ncols)):
59
#             if legend_map[i, j] is not None:
60
#                 lines = ax_array[i, j].get_lines()
61
#                 labels = _make_legend_labels(lines)
62
#                 if len(labels) > 1:
63
#                     legend_array[i, j] = ax.legend(
64
#                         lines, labels, loc=legend_map[i, j])
65
#     else:
66
#         legend_array = None
67
#
68
#     # Update the plot title (only if ax was not given)
69
#     sysnames = [response.sysname for response in data]
70
#     if ax is None and title is None:
71
#         title = "Name plot for " + ", ".join(sysnames)
72
#         _update_plot_title(title, fig, rcParams=rcParams)
73
#     elif ax == None:
74
#         _update_plot_title(title, fig, rcParams=rcParams, use_existing=False)
75
#
76
#     # Legacy processing of plot keyword
77
#     if plot is True:
78
#         return data
79
#
80
#     return ControlPlot(lines, ax_array, fig, legend=legend_map)
81

82
import itertools
9✔
83
import warnings
9✔
84
from os.path import commonprefix
9✔
85

86
import matplotlib as mpl
9✔
87
import matplotlib.pyplot as plt
9✔
88
import numpy as np
9✔
89

90
from . import config
9✔
91

92
__all__ = [
9✔
93
    'ControlPlot', 'suptitle', 'get_plot_axes', 'pole_zero_subplots',
94
    'rcParams', 'reset_rcParams']
95

96
#
97
# Style parameters
98
#
99

100
rcParams_default = {
9✔
101
    'axes.labelsize': 'small',
102
    'axes.titlesize': 'small',
103
    'figure.titlesize': 'medium',
104
    'legend.fontsize': 'x-small',
105
    'xtick.labelsize': 'small',
106
    'ytick.labelsize': 'small',
107
}
108
_ctrlplot_rcParams = rcParams_default.copy()    # provide access inside module
9✔
109
rcParams = _ctrlplot_rcParams                   # provide access outside module
9✔
110

111
_ctrlplot_defaults = {'ctrlplot.rcParams': _ctrlplot_rcParams}
9✔
112

113

114
#
115
# Control figure
116
#
117

118
class ControlPlot(object):
    """A class for returning control figures.

    This class is used as the return type for control plotting functions.
    It contains the information required to access portions of the plot
    that the user might want to adjust, as well as providing methods to
    modify some of the properties of the plot.

    A control figure consists of a :class:`matplotlib.figure.Figure` with
    an array of :class:`matplotlib.axes.Axes`.  Each axes in the figure has
    a number of lines that represent the data for the plot.  There may also
    be a legend present in one or more of the axes.

    Attributes
    ----------
    lines : array of list of :class:`matplotlib:Line2D`
        Array of Line2D objects for each line in the plot.  Generally, the
        shape of the array matches the subplots shape and the value of the
        array is a list of Line2D objects in that subplot.  Some plotting
        functions will return variants of this structure, as described in
        the individual documentation for the functions.
    axes : 2D array of :class:`matplotlib:Axes`
        Array of Axes objects for each subplot in the plot.
    figure : :class:`matplotlib:Figure`
        Figure on which the Axes are drawn.
    legend : :class:`matplotlib:.legend.Legend` (instance or ndarray)
        Legend object(s) for the plot.  If more than one legend is
        included, this will be an array with each entry being either None
        (for no legend) or a legend object.

    """
    def __init__(self, lines, axes=None, figure=None, legend=None):
        self.lines = lines
        if axes is None:
            # Recover the axes from the first line stored in each entry
            _get_axes = np.vectorize(lambda lines: lines[0].axes)
            axes = _get_axes(lines)
        self.axes = np.atleast_2d(axes)
        if figure is None:
            # All axes are assumed to live on the figure of the first one
            figure = self.axes[0, 0].figure
        self.figure = figure
        self.legend = legend

    # Implement methods and properties to allow legacy interface (np.array)
    def __iter__(self):
        # Must return an iterator: handing back the ndarray itself makes
        # iter() raise "TypeError: iter() returned non-iterator"
        return iter(self.lines)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        warnings.warn(
            "return of Line2D objects from plot function is deprecated in "
            "favor of ControlPlot; use out.lines to access Line2D objects",
            category=FutureWarning)
        return self.lines[item]

    def __setitem__(self, item, val):
        self.lines[item] = val

    shape = property(lambda self: self.lines.shape, None)

    def reshape(self, *args):
        """Return `lines` reshaped as by :meth:`numpy.ndarray.reshape`."""
        return self.lines.reshape(*args)

    def set_plot_title(self, title, frame='axes'):
        """Set the title for a control plot.

        This is a wrapper for the matplotlib `suptitle` function, but by
        setting ``frame`` to 'axes' (default) then the title is centered on
        the midpoint of the axes in the figure, rather than the center of
        the figure.  This usually looks better (particularly with
        multi-panel plots), though it takes longer to render.

        The title is always applied to the figure stored in this object
        and replaces any existing title.

        Parameters
        ----------
        title : str
            Title text.
        frame : str, optional
            Coordinate frame to use for centering: 'axes' (default) or 'figure'.

        """
        _update_plot_title(
            title, fig=self.figure, frame=frame, use_existing=False)
198

199
#
200
# User functions
201
#
202
# The functions below can be used by users to modify control plots or get
203
# information about them.
204
#
205

206
def suptitle(
        title, fig=None, frame='axes', **kwargs):
    """Add a centered title to a figure.

    .. deprecated:: 0.10.1
        Use :func:`ControlPlot.set_plot_title`.

    Emits a `FutureWarning` and then sets the title, replacing (rather
    than extending) any title already present.

    """
    deprecation_msg = "suptitle() is deprecated; use cplt.set_plot_title()"
    warnings.warn(deprecation_msg, FutureWarning)

    # Always overwrite the existing title (use_existing=False)
    _update_plot_title(
        title, fig=fig, frame=frame, use_existing=False, **kwargs)
218

219

220
# Create vectorized function to find axes from lines
221
def get_plot_axes(line_array):
    """Get an array of axes from an array of lines.

    .. deprecated:: 0.10.1
        This function will be removed in a future version of python-control.
        Use `cplt.axes` to obtain axes for an instance of :class:`ControlPlot`.

    This function can be used to return the set of axes corresponding
    to the line array that is returned by `time_response_plot`.  This
    is useful for generating an axes array that can be passed to
    subsequent plotting calls.

    Parameters
    ----------
    line_array : array of list of Line2D, or ControlPlot
        An array with elements corresponding to a list of lines appearing
        in an axes, matching the return type of a time response data plot.
        If a :class:`ControlPlot` is given, its `lines` attribute is used.

    Returns
    -------
    axes_array : array of Axes
        An array with elements corresponding to the Axes associated with
        the lines in `line_array`.

    Notes
    -----
    Only the first element of each array entry is used to determine the axes.

    """
    warnings.warn(
        "get_plot_axes() is deprecated; use cplt.axes()", FutureWarning)

    # Unwrap a ControlPlot so both input types share one code path
    if isinstance(line_array, ControlPlot):
        source = line_array.lines
    else:
        source = line_array

    axes_of_first_line = np.vectorize(lambda entry: entry[0].axes)
    return axes_of_first_line(source)
9✔
257

258

259
def pole_zero_subplots(
9✔
260
        nrows, ncols, grid=None, dt=None, fig=None, scaling=None,
261
        rcParams=None):
262
    """Create axes for pole/zero plot.
263

264
    Parameters
265
    ----------
266
    nrows, ncols : int
267
        Number of rows and columns.
268
    grid : True, False, or 'empty', optional
269
        Grid style to use.  Can also be a list, in which case each subplot
270
        will have a different style (columns then rows).
271
    dt : timebase, option
272
        Timebase for each subplot (or a list of timebases).
273
    scaling : 'auto', 'equal', or None
274
        Scaling to apply to the subplots.
275
    fig : :class:`matplotlib.figure.Figure`
276
        Figure to use for creating subplots.
277
    rcParams : dict
278
        Override the default parameters used for generating plots.
279
        Default is set up config.default['ctrlplot.rcParams'].
280

281
    Returns
282
    -------
283
    ax_array : array
284
        2D array of axes
285

286
    """
287
    from .grid import nogrid, sgrid, zgrid
9✔
288
    from .iosys import isctime
9✔
289

290
    if fig is None:
9✔
291
        fig = plt.gcf()
9✔
292
    rcParams = config._get_param('ctrlplot', 'rcParams', rcParams)
9✔
293

294
    if not isinstance(grid, list):
9✔
295
        grid = [grid] * nrows * ncols
9✔
296
    if not isinstance(dt, list):
9✔
297
        dt = [dt] * nrows * ncols
9✔
298

299
    ax_array = np.full((nrows, ncols), None)
9✔
300
    index = 0
9✔
301
    with plt.rc_context(rcParams):
9✔
302
        for row, col in itertools.product(range(nrows), range(ncols)):
9✔
303
            match grid[index], isctime(dt=dt[index]):
9✔
304
                case 'empty', _:        # empty grid
9✔
305
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
306

307
                case True, True:        # continuous time grid
9✔
308
                    ax_array[row, col], _ = sgrid(
9✔
309
                        (nrows, ncols, index+1), scaling=scaling)
310

311
                case True, False:       # discrete time grid
9✔
312
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
313
                    zgrid(ax=ax_array[row, col], scaling=scaling)
9✔
314

315
                case False | None, _:   # no grid (just stability boundaries)
9✔
316
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
317
                    nogrid(
9✔
318
                        ax=ax_array[row, col], dt=dt[index], scaling=scaling)
319
            index += 1
9✔
320
    return ax_array
9✔
321

322

323
def reset_rcParams():
    """Reset rcParams to default values for control plots.

    The shared parameter dictionary is modified in place, so existing
    references to it see the restored values.  Keys that are not part of
    the defaults are left untouched.

    """
    for key, default_value in rcParams_default.items():
        _ctrlplot_rcParams[key] = default_value
9✔
326

327

328
#
329
# Utility functions
330
#
331
# These functions are used by plotting routines to provide a consistent way
332
# of processing and displaying information.
333
#
334

335
def _process_ax_keyword(
        axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False,
        create_axes=True):
    """Process ax keyword to plotting commands.

    This function processes the `ax` keyword to plotting commands.  If no
    ax keyword is passed, the current figure is checked to see if it has
    the correct shape.  If the shape matches the desired shape, then the
    current figure and axes are returned.  Otherwise a new figure is
    created with axes of the desired shape.

    If `create_axes` is False and a new/empty figure is returned, then axs
    is an array of the proper shape but None for each element.  This allows
    the calling function to do the actual axis creation (needed for
    curvilinear grids that use the AxisArtist module).

    Legacy behavior: some of the older plotting commands use a axes label
    to identify the proper axes for plotting.  This behavior is supported
    through the use of the label keyword, but will only work if shape ==
    (1, 1) and squeeze == True.

    Parameters
    ----------
    axs : Axes, array_like of Axes, or None
        Axes given by the user (None to use or create axes on the current
        figure).
    shape : 2-element sequence of int, optional
        Desired (nrows, ncols) layout of the axes array.
    rcParams : dict, optional
        Matplotlib rc settings applied while new axes are created.
    squeeze : bool, optional
        If True, return a single Axes for a (1, 1) layout and otherwise
        an array with length-one dimensions removed.
    clear_text : bool, optional
        If True and the existing axes are reused, hide all text objects
        attached to the figure.
    create_axes : bool, optional
        If False, do not create axes; return an array of None instead.

    Returns
    -------
    fig : Figure
        Figure containing the axes.
    axs : Axes or array of Axes
        Axes (or array of None if `create_axes` is False and the figure
        did not already hold the right number of axes).

    Raises
    ------
    ValueError
        If `axs` is given but cannot be reshaped to `shape`.

    """
    # Normalize so that list-valued shapes compare equal to (1, 1) below
    shape = tuple(shape)

    if axs is None:
        fig = plt.gcf()         # get current figure (or create new one)
        axs = fig.get_axes()

        # Check to see if axes are the right shape; if not, create new figure
        # Note: can't actually check the shape, just the total number of axes
        if len(axs) != np.prod(shape):
            with plt.rc_context(rcParams):
                if len(axs) != 0 and create_axes:
                    # Create a new figure
                    fig, axs = plt.subplots(*shape, squeeze=False)
                elif create_axes:
                    # Create new axes on (empty) figure
                    axs = fig.subplots(*shape, squeeze=False)
                else:
                    # Create an empty array and let user create axes
                    axs = np.full(shape, None)
            if create_axes:     # if not creating axes, leave these to caller
                fig.set_layout_engine('tight')
                fig.align_labels()

        else:
            # Use the existing axes, properly reshaped
            axs = np.asarray(axs).reshape(*shape)

            if clear_text:
                # Hide any old text on the current figure (the text objects
                # themselves stay attached to the figure)
                for text in fig.texts:
                    text.set_visible(False)
    else:
        axs = np.atleast_1d(axs)
        try:
            axs = axs.reshape(shape)
        except ValueError:
            # The numpy error adds nothing to the message below
            raise ValueError(
                "specified axes are not the right shape; "
                f"got {axs.shape} but expecting {shape}") from None
        fig = axs[0, 0].figure

    # Process the squeeze keyword
    if squeeze and shape == (1, 1):
        axs = axs[0, 0]         # Just return the single axes object
    elif squeeze:
        axs = axs.squeeze()

    return fig, axs
9✔
404

405

406
# Turn label keyword into array indexed by trace, output, input
407
# TODO: move to ctrlutil.py and update parameter names to reflect general use
408
def _process_line_labels(label, ntraces=1, ninputs=0, noutputs=0):
    """Convert the `label` keyword into an array of line labels.

    Parameters
    ----------
    label : str, array_like of str, or None
        Label(s) given by the user.  A single string is repeated for
        every trace.
    ntraces : int, optional
        Number of traces.
    ninputs, noutputs : int, optional
        Sizes of the second and third dimensions of the result.  If
        either is not positive, the labels are returned as converted by
        `numpy.asarray` without reshaping.

    Returns
    -------
    line_labels : ndarray or None
        None if `label` is None.  Otherwise an array of labels, with
        shape (ntraces, ninputs, noutputs) when both `ninputs` and
        `noutputs` are positive.

    Raises
    ------
    ValueError
        If the labels cannot be converted to an array or do not fit the
        requested shape.

    """
    if label is None:
        return None

    # A single string applies to all traces
    candidates = [label] * ntraces if isinstance(label, str) else label

    # Convert to an ndarray, if not done already
    try:
        labels = np.asarray(candidates)
    except ValueError:
        raise ValueError("label must be a string or array_like")

    # Turn the data into a 3D array of appropriate shape
    # TODO: allow more sophisticated broadcasting (and error checking)
    full_shape = (ntraces, ninputs, noutputs)
    try:
        if ninputs > 0 and noutputs > 0:
            one_label_per_trace = labels.ndim == 1 and labels.size == ntraces
            if one_label_per_trace:
                # Share each trace label across all of its subplots
                labels = np.broadcast_to(
                    labels.reshape(ntraces, 1, 1), full_shape)
            else:
                labels = labels.reshape(full_shape)
    except ValueError:
        if labels.shape[0] == ntraces:
            raise ValueError("labels must be given for each input/output pair")
        raise ValueError("number of labels must match number of traces")

    return labels
9✔
438

439

440
# Get labels for all lines in an axes
441
def _get_line_labels(ax, use_color=True):
    """Get the lines and labels to show in a legend for an axes.

    Parameters
    ----------
    ax : Axes
        Axes whose lines (as returned by ``ax.get_lines()``) are examined.
    use_color : bool, optional
        If True (default), lines whose label starts with "Unknown" are
        relabeled as "Unknown-<n>", with the counter advanced each time
        the line color changes.

    Returns
    -------
    lines : list of Line2D
        First line found for each distinct (label, color) pair.
    labels : list of str
        Labels corresponding to `lines`.

    Notes
    -----
    Lines with an empty label or a label starting with an underscore are
    skipped.

    """
    labels_colors, lines = [], []
    last_color, counter = None, 0       # label unknown systems
    for line in ax.get_lines():
        label = line.get_label()
        color = line.get_color()
        if use_color and label.startswith("Unknown"):
            label = f"Unknown-{counter}"
            if last_color != color:
                counter += 1
            last_color = color
        elif not label or label.startswith('_'):
            # Nothing to show (an empty label used to raise IndexError)
            continue

        # Keep only the first line for each (label, color) combination
        if (label, color) not in labels_colors:
            lines.append(line)
            labels_colors.append((label, color))

    return lines, [label for label, color in labels_colors]
9✔
460

461

462
def _process_legend_keywords(
        kwargs, shape=None, default_loc='center right'):
    """Extract and validate the legend keywords from `kwargs`.

    The keywords 'legend_loc', 'legend_map' and 'show_legend' are removed
    from `kwargs` (which is modified in place).

    Parameters
    ----------
    kwargs : dict
        Keyword arguments passed to the plotting function.
    shape : 2-element sequence of int, optional
        Shape of the axes array.  If None, 'legend_map' is not accepted.
    default_loc : str or int, optional
        Legend location to use if 'legend_loc' was not given.

    Returns
    -------
    legend_loc : str or int
        Location of the legend.
    legend_map : ndarray or None
        2D array of legend locations (one per axes), if given.
    show_legend : bool or None
        False if 'legend_loc' or 'legend_map' was False, True if either
        was given with another value, and otherwise the value of the
        'show_legend' keyword (None if not given).

    Raises
    ------
    TypeError
        If 'legend_map' is given but `shape` is None.
    ValueError
        If 'legend_loc' is not a string or int, or if 'legend_map' does
        not match `shape`.

    """
    legend_loc = kwargs.pop('legend_loc', None)
    if shape is None and 'legend_map' in kwargs:
        raise TypeError("unexpected keyword argument 'legend_map'")
    legend_map = kwargs.pop('legend_map', None)
    show_legend = kwargs.pop('show_legend', None)

    if legend_loc is False or legend_map is False:
        # Legend explicitly turned off via legend_loc or legend_map
        if show_legend is True:
            warnings.warn(
                "show_legend ignored; legend_loc or legend_map was given")
        show_legend = False
        legend_loc = legend_map = None
    elif legend_loc is not None or legend_map is not None:
        # If legend_loc or legend_map were given, always show the legend
        if show_legend is False:
            warnings.warn(
                "show_legend ignored; legend_loc or legend_map was given")
        show_legend = True

    if legend_loc is None:
        legend_loc = default_loc
    elif not isinstance(legend_loc, (int, str)):
        raise ValueError("legend_loc must be string or int")

    # Make sure the legend map is the right size
    if legend_map is not None:
        legend_map = np.atleast_2d(legend_map)
        # ndarray.shape is a tuple, so normalize list-valued shapes first
        if legend_map.shape != tuple(shape):
            raise ValueError("legend_map shape must match axes shape")

    return legend_loc, legend_map, show_legend
9✔
496

497

498
# Utility function to make legend labels
499
def _make_legend_labels(labels, ignore_common=False):
    """Strip information shared by all labels to shorten legend entries.

    Parameters
    ----------
    labels : list of str
        Labels to process.  A single-element list is returned unchanged.
    ignore_common : bool, optional
        If True, leave the labels as they are (nothing is stripped).

    Returns
    -------
    labels : list of str
        Labels with the common leading part (through the last ', ' of the
        common prefix) and the common trailing part (from the first ','
        of the common suffix) removed.

    """
    if len(labels) == 1:
        return labels

    # Shared leading text, cut back to just after its last ', '
    shared_head = commonprefix(labels)
    cut = shared_head.rfind(', ')
    if cut < 0 or ignore_common:
        prefix_len = 0
    elif cut > 0:
        prefix_len = cut + 2
    else:
        prefix_len = len(shared_head)

    # Shared trailing text, kept only from its first comma onward
    shared_tail = commonprefix([text[::-1] for text in labels])[::-1]
    comma = shared_tail.find(',')
    suffix_len = len(shared_tail) - comma if comma >= 0 else 0

    # Strip the labels of common information
    stop = -suffix_len if (suffix_len > 0 and not ignore_common) else None
    return [text[prefix_len:stop] for text in labels]
9✔
527

528

529
def _update_plot_title(
        title, fig=None, frame='axes', use_existing=True, **kwargs):
    """Set or extend the title (suptitle) of a figure.

    Parameters
    ----------
    title : str, None, or False
        New title text.  If None or False, nothing is done.
    fig : Figure, optional
        Figure to update.  Defaults to the current figure.
    frame : str, optional
        Coordinate frame used for centering the title: 'axes' (default)
        centers on the midpoint of the axes in the figure, 'figure'
        centers on the figure.
    use_existing : bool, optional
        If True (default) and the figure already has a title, the part of
        `title` that differs from the existing title is appended to the
        existing title instead of replacing it.
    **kwargs : optional
        An 'rcParams' entry, if present, is removed and used as the
        matplotlib rc context; the remaining keywords are passed to
        ``fig.suptitle``.

    Raises
    ------
    ValueError
        If `frame` is not 'axes' or 'figure'.

    """
    if title is False or title is None:
        return
    if fig is None:
        fig = plt.gcf()
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    if use_existing:
        # Get the current title, if it exists
        old_title = (
            None if fig._suptitle is None else fig._suptitle.get_text())

        if old_title is not None:
            # Find the common part of the titles
            common_prefix = commonprefix([old_title, title])

            # Back up to the last space
            last_space = common_prefix.rfind(' ')
            if last_space > 0:
                common_prefix = common_prefix[:last_space]
            common_len = len(common_prefix)

            # Add the new part of the title (usually the system name)
            if old_title[common_len:] != title[common_len:]:
                separator = ',' if len(common_prefix) > 0 else ';'
                title = old_title + separator + title[common_len:]

    if frame == 'figure':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)

    elif frame == 'axes':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)           # Place title in center
            # Lay out `fig` itself; plt.tight_layout() would act on the
            # current figure, which need not be `fig`
            fig.tight_layout()                      # Put everything into place
            xc, _ = _find_axes_center(fig, fig.get_axes())
            fig.suptitle(title, x=xc, **kwargs)     # Redraw title, centered

    else:
        raise ValueError(f"unknown frame '{frame}'")
9✔
569

570

571
def _find_axes_center(fig, axs):
    """Find the midpoint of a set of axes in figure coordinates.

    The lower-left and upper-right corners of each axes are mapped through
    the axes transform and then through the inverse of the figure
    transform.  The result is the center of the box bounding all corners.

    Parameters
    ----------
    fig : Figure
        Figure supplying the ``transFigure`` transform.
    axs : iterable of Axes
        Axes to include.  If empty, (0.5, 0.5) is returned.

    Returns
    -------
    xc, yc : float
        Coordinates of the center point.

    """
    to_figure = fig.transFigure.inverted()

    # Start from an inverted unit box so that any axes tightens the bounds
    x_low, x_high = 1, 0
    y_low, y_high = 1, 0
    for ax in axs:
        lower_left = to_figure.transform(ax.transAxes.transform((0, 0)))
        upper_right = to_figure.transform(ax.transAxes.transform((1, 1)))

        x_low = min(lower_left[0], x_low)
        x_high = max(upper_right[0], x_high)
        y_low = min(lower_left[1], y_low)
        y_high = max(upper_right[1], y_high)

    x_center = np.sum([x_low, x_high])/2
    y_center = np.sum([y_low, y_high])/2
    return (x_center, y_center)
9✔
587

588

589
# Internal function to add arrows to a curve
590
def _add_arrows_to_line2D(
        axes, line, arrow_locs=(0.2, 0.4, 0.6, 0.8),
        arrowstyle='-|>', arrowsize=1, dir=1):
    """
    Add arrows to a matplotlib.lines.Line2D at selected locations.

    Parameters
    ----------
    axes: Axes object as returned by axes command (or gca)
    line: Line2D object as returned by plot command
    arrow_locs: sequence of locations where to insert arrows, given as
        fractions of the total arc length of the line
    arrowstyle: style of the arrow
    arrowsize: size of the arrow (accepted but not used by this function)
    dir: offset (in samples) from the arrow head back to the arrow tail;
        1 (default) draws the tail at the sample before the head

    Returns
    -------
    arrows: list of arrows (empty if the line has fewer than two points)

    Raises
    ------
    ValueError: if `line` is not a Line2D object
    NotImplementedError: if the line color or linewidth is an array

    Based on https://stackoverflow.com/questions/26911898/

    """
    # Get the coordinates of the line, in plot coordinates
    if not isinstance(line, mpl.lines.Line2D):
        raise ValueError("expected a matplotlib.lines.Line2D object")
    x, y = line.get_xdata(), line.get_ydata()

    # Determine the arrow properties
    arrow_kw = {"arrowstyle": arrowstyle}

    color = line.get_color()
    if isinstance(color, np.ndarray):
        raise NotImplementedError("multicolor lines not supported")
    arrow_kw['color'] = color

    linewidth = line.get_linewidth()
    if isinstance(linewidth, np.ndarray):
        raise NotImplementedError("multiwidth lines not supported")
    arrow_kw['linewidth'] = linewidth

    # Figure out the size of the axes (length of diagonal)
    xlim, ylim = axes.get_xlim(), axes.get_ylim()
    ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
    diag = np.linalg.norm(ul - lr)

    # Compute the arc length along the curve
    s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))
    if len(s) == 0:
        # Fewer than two points: no segment on which to place an arrow
        return []

    # Truncate the number of arrows if the curve is short
    # TODO: figure out a smarter way to do this
    frac = min(s[-1] / diag, 1)
    if len(arrow_locs) and frac < 0.05:
        arrow_locs = []         # too short; no arrows at all
    elif len(arrow_locs) and frac < 0.2:
        arrow_locs = [0.5]      # single arrow in the middle

    # Plot the arrows (and return list of patches)
    arrows = []
    for loc in arrow_locs:
        n = np.searchsorted(s, s[-1] * loc)

        if dir == 1 and n == 0:
            # Move the arrow forward by one if it is at start of a segment
            n = 1

        # Place the head of the arrow at the desired location
        arrow_head = [x[n], y[n]]
        arrow_tail = [x[n - dir], y[n - dir]]

        p = mpl.patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=axes.transData, lw=0,
            **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)
    return arrows
9✔
667

668

669
def _get_color_offset(ax, color_cycle=None):
    """Get color offset based on current lines.

    This function determines the index into the color cycle of the next
    color to use, based on the color of the last line in the axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes containing already plotted lines.
    color_cycle : list of matplotlib color specs, optional
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
        color cycle.

    Returns
    -------
    color_offset : int
        Index into `color_cycle` of the color for the next line to be
        drawn.  This is 0 if the axes has no lines or if the color of the
        last line is not in the cycle.
    color_cycle : list of matplotlib color specs
        Color cycle used to determine colors.

    """
    if color_cycle is None:
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    next_index = 0
    existing = ax.lines
    if len(existing) > 0:
        previous = existing[-1].get_color()
        if previous in color_cycle:
            # Continue with the color following the last one used
            next_index = color_cycle.index(previous) + 1

    # Wrap around when the last color in the cycle was just used
    return next_index % len(color_cycle), color_cycle
9✔
701

702

703
def _get_color(
        colorspec, offset=None, fmt=None, ax=None, lines=None,
        color_cycle=None):
    """Get color to use for plotting line.

    This function returns the color to be used for the line to be drawn (or
    None if the default color cycle for the axes should be used).

    Parameters
    ----------
    colorspec : matplotlib color specification
        User-specified color (or None).  If a dict is given, its 'color'
        entry (if any) is removed from the dict and returned.
    offset : int, optional
        Offset into the color cycle (for multi-trace plots).  Offsets
        past the end of the cycle wrap around.
    fmt : str, optional
        Format string passed to plotting command.
    ax : matplotlib.axes.Axes, optional
        Axes containing already plotted lines.
    lines : list of matplotlib.lines.Line2D, optional
        List of plotted lines.  If not given, use ``ax.lines``.
    color_cycle : list of matplotlib color specs, optional
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
        color cycle.

    Returns
    -------
    color : matplotlib color spec
        Color to use for this line (or None for matplotlib default).

    """
    # See if the color was explicitly specified by the user
    if isinstance(colorspec, dict):
        if 'color' in colorspec:
            return colorspec.pop('color')
    elif fmt is not None and any(
            isinstance(arg, str) and any(c in arg for c in "bgrcmykw#")
            for arg in fmt):
        # Only defer to *fmt if one of its entries really contains a color
        # code (a bare list here would be truthy for any non-empty fmt)
        return None             # *fmt will set the color
    elif colorspec is not None:
        return colorspec

    # Figure out what color cycle to use, if not given by caller
    if color_cycle is None:
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # Find the lines that we should pay attention to
    if lines is None and ax is not None:
        lines = ax.lines

    # If we were passed a set of lines, try to increment color from previous
    if offset is not None:
        return color_cycle[offset % len(color_cycle)]
    elif lines is not None:
        color_offset = 0
        # Use `lines` (not ax.lines) so explicitly passed lines work
        # even when no axes was given
        if len(lines) > 0:
            last_color = lines[-1].get_color()
            if last_color in color_cycle:
                color_offset = color_cycle.index(last_color) + 1
        color_offset = color_offset % len(color_cycle)
        return color_cycle[color_offset]
    else:
        return None
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc