• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13107996232

03 Feb 2025 06:53AM UTC coverage: 94.731% (+0.02%) from 94.709%
13107996232

push

github

web-flow
Merge pull request #1094 from murrayrm/userguide-22Dec2024

Updated user documentation (User Guide, Reference Manual)

9673 of 10211 relevant lines covered (94.73%)

8.28 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.97
control/ctrlplot.py
1
# ctrlplot.py - utility functions for plotting
2
# RMM, 14 Jun 2024
3
#
4

5
"""Utility functions for plotting.
6

7
This module contains a collection of functions that are used by
8
various plotting functions.
9

10
"""
11

12
# Code pattern for control system plotting functions:
13
#
14
# def name_plot(sysdata, *fmt, plot=None, **kwargs):
15
#     # Process keywords and set defaults
16
#     ax = kwargs.pop('ax', None)
17
#     color = kwargs.pop('color', None)
18
#     label = kwargs.pop('label', None)
19
#     rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
20
#
21
#     # Make sure all keyword arguments were processed (if not checked later)
22
#     if kwargs:
23
#         raise TypeError("unrecognized keywords: ", str(kwargs))
24
#
25
#     # Process the data (including generating responses for systems)
26
#     sysdata = list(sysdata)
27
#     if any([isinstance(sys, InputOutputSystem) for sys in sysdata]):
28
#         data = name_response(sysdata)
29
#     nrows = max([data.noutputs for data in sysdata])
30
#     ncols = max([data.ninputs for data in sysdata])
31
#
32
#     # Legacy processing of plot keyword
33
#     if plot is False:
34
#         return data.x, data.y
35
#
36
#     # Figure out the shape of the plot and find/create axes
37
#     fig, ax_array = _process_ax_keyword(ax, (nrows, ncols), rcParams)
38
#     legend_loc, legend_map, show_legend = _process_legend_keywords(
39
#         kwargs, (nrows, ncols), 'center right')
40
#
41
#     # Customize axes (curvilinear grids, shared axes, etc)
42
#
43
#     # Plot the data
#     lines = np.empty(ax_array.shape, dtype=object)
#     line_labels = _process_line_labels(label, ntraces, nrows, ncols)
#     color_offset, color_cycle = _get_color_offset(ax_array[0, 0])
#     for i, j in itertools.product(range(nrows), range(ncols)):
#         lines[i, j] = []
#         for k in range(ntraces):
#             trace_color = _get_color(
#                 color, fmt=fmt, offset=color_offset + k,
#                 color_cycle=color_cycle)
#             trace_label = \
#                 None if line_labels is None else line_labels[k, i, j]
#             lines[i, j] += ax_array[i, j].plot(
#                 data.x, data.y, *fmt, color=trace_color, label=trace_label)
55
#
56
#     # Customize and label the axes
57
#     for i, j in itertools.product(range(nrows), range(ncols)):
58
#         ax_array[i, j].set_xlabel("x label")
59
#         ax_array[i, j].set_ylabel("y label")
60
#
61
#     # Create legends
#     if show_legend is not False:
#         legend_array = np.full(ax_array.shape, None, dtype=object)
#         for i, j in itertools.product(range(nrows), range(ncols)):
#             if legend_map[i, j] is not None:
#                 lines, labels = _get_line_labels(ax_array[i, j])
#                 labels = _make_legend_labels(labels)
#                 if len(labels) > 1:
#                     legend_array[i, j] = ax_array[i, j].legend(
#                         lines, labels, loc=legend_map[i, j])
#     else:
#         legend_array = None
73
#
74
#     # Update the plot title (only if ax was not given)
75
#     sysnames = [response.sysname for response in data]
76
#     if ax is None and title is None:
77
#         title = "Name plot for " + ", ".join(sysnames)
78
#         _update_plot_title(title, fig, rcParams=rcParams)
79
#     elif ax == None:
80
#         _update_plot_title(title, fig, rcParams=rcParams, use_existing=False)
81
#
82
#     # Legacy processing of plot keyword
83
#     if plot is True:
84
#         return data
85
#
86
#     return ControlPlot(lines, ax_array, fig, legend=legend_map)
87

88
import itertools
9✔
89
import warnings
9✔
90
from os.path import commonprefix
9✔
91

92
import matplotlib as mpl
9✔
93
import matplotlib.pyplot as plt
9✔
94
import numpy as np
9✔
95

96
from . import config
9✔
97

98
# Public API of this module (names exported by `from ... import *`)
__all__ = [
    'ControlPlot',
    'suptitle',
    'get_plot_axes',
    'pole_zero_subplots',
    'rcParams',
    'reset_rcParams',
]
101

102
#
103
# Style parameters
104
#
105

106
# Default matplotlib rc settings (font sizes) applied to control plots.
# reset_rcParams() copies these values back into the working dictionary.
rcParams_default = dict([
    ('axes.labelsize', 'small'),
    ('axes.titlesize', 'small'),
    ('figure.titlesize', 'medium'),
    ('legend.fontsize', 'x-small'),
    ('xtick.labelsize', 'small'),
    ('ytick.labelsize', 'small'),
])

# Working copy of the defaults.  The module-internal name and the public
# `rcParams` name refer to the SAME dictionary, so changes made through
# either one are visible through the other (and through the config entry).
_ctrlplot_rcParams = dict(rcParams_default)
rcParams = _ctrlplot_rcParams

_ctrlplot_defaults = {'ctrlplot.rcParams': _ctrlplot_rcParams}
118

119

120
#
121
# Control figure
122
#
123

124
class ControlPlot():
    """Return class for control plotting functions.

    This class is used as the return type for control plotting functions.
    It contains the information required to access portions of the plot
    that the user might want to adjust, as well as providing methods to
    modify some of the properties of the plot.

    A control figure consists of a `matplotlib.figure.Figure` with
    an array of `matplotlib.axes.Axes`.  Each axes in the figure has
    a number of lines that represent the data for the plot.  There may also
    be a legend present in one or more of the axes.

    Parameters
    ----------
    lines : array of list of `matplotlib.lines.Line2D`
        Array of Line2D objects for each line in the plot.  Generally, the
        shape of the array matches the subplots shape and the value of the
        array is a list of Line2D objects in that subplot.  Some plotting
        functions will return variants of this structure, as described in
        the individual documentation for the functions.
    axes : 2D array of `matplotlib.axes.Axes`, optional
        Array of Axes objects for each subplot in the plot.  If not given,
        the axes of the first line in each entry of `lines` are used.
    figure : `matplotlib.figure.Figure`, optional
        Figure on which the Axes are drawn.  If not given, the figure
        containing the upper-left axes is used.
    legend : `matplotlib.legend.Legend` (instance or ndarray)
        Legend object(s) for the plot.  If more than one legend is
        included, this will be an array with each entry being either None
        (for no legend) or a legend object.

    """
    def __init__(self, lines, axes=None, figure=None, legend=None):
        self.lines = lines
        if axes is None:
            # Look up the axes from the first line in each array entry
            _get_axes = np.vectorize(lambda lines: lines[0].axes)
            axes = _get_axes(lines)
        self.axes = np.atleast_2d(axes)
        if figure is None:
            figure = self.axes[0, 0].figure
        self.figure = figure
        self.legend = legend

    # Implement methods and properties to allow legacy interface (np.array).
    # Note: __iter__ must return an *iterator*; returning the array itself
    # makes iter(), for-loops and tuple unpacking raise TypeError.
    def __iter__(self):
        return iter(self.lines)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, item):
        warnings.warn(
            "return of Line2D objects from plot function is deprecated in "
            "favor of ControlPlot; use out.lines to access Line2D objects",
            category=FutureWarning)
        return self.lines[item]

    def __setitem__(self, item, val):
        self.lines[item] = val

    shape = property(lambda self: self.lines.shape, None)

    def reshape(self, *args):
        """Reshape lines array (legacy)."""
        return self.lines.reshape(*args)

    def set_plot_title(self, title, frame='axes', **kwargs):
        """Set the title for a control plot.

        This is a wrapper for the matplotlib `suptitle` function, but by
        setting `frame` to 'axes' (default) then the title is centered on
        the midpoint of the axes in the figure, rather than the center of
        the figure.  This usually looks better (particularly with
        multi-panel plots), though it takes longer to render.  Any
        existing title on the figure is replaced.

        Parameters
        ----------
        title : str
            Title text.
        frame : str, optional
            Coordinate frame for centering: 'axes' (default) or 'figure'.
        **kwargs : `matplotlib.pyplot.suptitle` keywords, optional
            Additional keywords (passed to matplotlib).

        """
        _update_plot_title(
            title, fig=self.figure, frame=frame, use_existing=False,
            **kwargs)
205

206
#
207
# User functions
208
#
209
# The functions below can be used by users to modify control plots or get
210
# information about them.
211
#
212

213
def suptitle(title, fig=None, frame='axes', **kwargs):
    """Add a centered title to a figure.

    .. deprecated:: 0.10.1
        Use `ControlPlot.set_plot_title`.

    Parameters
    ----------
    title : str
        Title text; replaces any existing figure title.
    fig : `matplotlib.figure.Figure`, optional
        Figure to title.  Defaults to the current figure.
    frame : str, optional
        Coordinate frame for centering: 'axes' (default) or 'figure'.
    **kwargs : `matplotlib.pyplot.suptitle` keywords, optional
        Additional keywords (passed to matplotlib).

    """
    deprecation_message = "suptitle() is deprecated; use cplt.set_plot_title()"
    warnings.warn(deprecation_message, FutureWarning)
    _update_plot_title(
        title, fig=fig, frame=frame, use_existing=False, **kwargs)
225

226

227
# Create vectorized function to find axes from lines
228
def get_plot_axes(line_array):
    """Get a list of axes from an array of lines.

    .. deprecated:: 0.10.1
        This function will be removed in a future version of python-control.
        Use `cplt.axes` to obtain axes for an instance of `ControlPlot`.

    This function can be used to return the set of axes corresponding
    to the line array that is returned by `time_response_plot`.  This
    is useful for generating an axes array that can be passed to
    subsequent plotting calls.

    Parameters
    ----------
    line_array : array of list of `matplotlib.lines.Line2D` or `ControlPlot`
        A 2D array with elements corresponding to a list of lines appearing
        in an axes, matching the return type of a time response data plot.
        If a `ControlPlot` is given, its `lines` attribute is used.

    Returns
    -------
    axes_array : array of `matplotlib.axes.Axes`
        A 2D array with elements corresponding to the Axes associated with
        the lines in `line_array`.

    Notes
    -----
    Only the first element of each array entry is used to determine the axes.

    """
    # `axes` is an attribute of ControlPlot, not a method
    warnings.warn(
        "get_plot_axes() is deprecated; use cplt.axes", FutureWarning)
    if isinstance(line_array, ControlPlot):
        line_array = line_array.lines

    # Take the axes of the first line in each array entry
    _get_axes = np.vectorize(lambda lines: lines[0].axes)
    return _get_axes(line_array)
264

265

266
def pole_zero_subplots(
9✔
267
        nrows, ncols, grid=None, dt=None, fig=None, scaling=None,
268
        rcParams=None):
269
    """Create axes for pole/zero plot.
270

271
    Parameters
272
    ----------
273
    nrows, ncols : int
274
        Number of rows and columns.
275
    grid : True, False, or 'empty', optional
276
        Grid style to use.  Can also be a list, in which case each subplot
277
        will have a different style (columns then rows).
278
    dt : timebase, option
279
        Timebase for each subplot (or a list of timebases).
280
    scaling : 'auto', 'equal', or None
281
        Scaling to apply to the subplots.
282
    fig : `matplotlib.figure.Figure`
283
        Figure to use for creating subplots.
284
    rcParams : dict
285
        Override the default parameters used for generating plots.
286
        Default is set by `config.defaults['ctrlplot.rcParams']`.
287

288
    Returns
289
    -------
290
    ax_array : ndarray
291
        2D array of axes.
292

293
    """
294
    from .grid import nogrid, sgrid, zgrid
9✔
295
    from .iosys import isctime
9✔
296

297
    if fig is None:
9✔
298
        fig = plt.gcf()
9✔
299
    rcParams = config._get_param('ctrlplot', 'rcParams', rcParams)
9✔
300

301
    if not isinstance(grid, list):
9✔
302
        grid = [grid] * nrows * ncols
9✔
303
    if not isinstance(dt, list):
9✔
304
        dt = [dt] * nrows * ncols
9✔
305

306
    ax_array = np.full((nrows, ncols), None)
9✔
307
    index = 0
9✔
308
    with plt.rc_context(rcParams):
9✔
309
        for row, col in itertools.product(range(nrows), range(ncols)):
9✔
310
            match grid[index], isctime(dt=dt[index]):
9✔
311
                case 'empty', _:        # empty grid
9✔
312
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
313

314
                case True, True:        # continuous-time grid
9✔
315
                    ax_array[row, col], _ = sgrid(
9✔
316
                        (nrows, ncols, index+1), scaling=scaling)
317

318
                case True, False:       # discrete-time grid
9✔
319
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
320
                    zgrid(ax=ax_array[row, col], scaling=scaling)
9✔
321

322
                case False | None, _:   # no grid (just stability boundaries)
9✔
323
                    ax_array[row, col] = fig.add_subplot(nrows, ncols, index+1)
9✔
324
                    nogrid(
9✔
325
                        ax=ax_array[row, col], dt=dt[index], scaling=scaling)
326
            index += 1
9✔
327
    return ax_array
9✔
328

329

330
def reset_rcParams():
    """Restore the control plot rcParams to their default values.

    Each entry of `rcParams_default` is copied into the module's working
    rcParams dictionary in place, so every name bound to that dictionary
    (including the public `rcParams`) sees the restored values.
    """
    for key, value in rcParams_default.items():
        _ctrlplot_rcParams[key] = value
333

334

335
#
336
# Utility functions
337
#
338
# These functions are used by plotting routines to provide a consistent way
339
# of processing and displaying information.
340
#
341

342
def _process_ax_keyword(
        axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False,
        create_axes=True, sharex=False, sharey=False):
    """Process ax keyword to plotting commands.

    This function processes the `ax` keyword to plotting commands.  If no
    ax keyword is passed, the current figure is checked to see if it has
    the correct shape.  If the shape matches the desired shape, then the
    current figure and axes are returned.  Otherwise a new figure is
    created with axes of the desired shape (or, if the current figure has
    no axes, new axes are created on the current figure).

    If `create_axes` is False and the current figure does not have the
    right number of axes, the current figure is returned together with an
    array of the proper shape containing None for each element (the
    figure may already contain other axes).  This allows the calling
    function to do the actual axis creation (needed for curvilinear grids
    that use the AxisArtist module).

    Parameters
    ----------
    axs : `matplotlib.axes.Axes`, array_like of Axes, or None
        Axes specified by the user (None to use the current figure).
    shape : tuple of int, optional
        Desired (nrows, ncols) shape of the axes array.
    rcParams : dict, optional
        Matplotlib rc parameters used when creating new axes.
    squeeze : bool, optional
        If True, return a single Axes for shape (1, 1) and a squeezed
        array otherwise.
    clear_text : bool, optional
        If True and the axes of the current figure are reused, hide the
        figure-level text artists (`fig.texts`).
    create_axes : bool, optional
        If False, do not create axes (see above).
    sharex, sharey : bool or str, optional
        Passed to `subplots` when new axes are created.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        Figure containing the axes.
    axs : ndarray of `matplotlib.axes.Axes` (or a single Axes)
        Axes array of the requested shape (see `squeeze`).

    Raises
    ------
    ValueError
        If `axs` is given but cannot be reshaped to `shape`.

    """
    if axs is None:
        fig = plt.gcf()         # get current figure (or create new one)
        axs = fig.get_axes()

        # Check to see if axes are the right shape; if not, create new figure
        # Note: can't actually check the shape, just the total number of axes
        if len(axs) != np.prod(shape):
            with plt.rc_context(rcParams):
                if len(axs) != 0 and create_axes:
                    # Create a new figure
                    fig, axs = plt.subplots(
                        *shape, sharex=sharex, sharey=sharey, squeeze=False)
                elif create_axes:
                    # Create new axes on (empty) figure
                    axs = fig.subplots(
                        *shape, sharex=sharex, sharey=sharey, squeeze=False)
                else:
                    # Create an empty array and let caller create axes
                    axs = np.full(shape, None)
            if create_axes:     # if not creating axes, leave these to caller
                fig.set_layout_engine('tight')
                fig.align_labels()

        else:
            # Use the existing axes, properly reshaped
            axs = np.asarray(axs).reshape(*shape)

            if clear_text:
                # Hide any old text on the current figure.  The Text
                # artists are only made invisible; they remain in fig.texts.
                for text in fig.texts:
                    text.set_visible(False)
    else:
        axs = np.atleast_1d(axs)
        try:
            axs = axs.reshape(shape)
        except ValueError:
            raise ValueError(
                "specified axes are not the right shape; "
                f"got {axs.shape} but expecting {shape}") from None
        fig = axs[0, 0].figure

    # Process the squeeze keyword
    if squeeze and shape == (1, 1):
        axs = axs[0, 0]         # Just return the single axes object
    elif squeeze:
        axs = axs.squeeze()

    return fig, axs
413

414

415
# Turn label keyword into array indexed by trace, output, input
416
# TODO: move to ctrlutil.py and update parameter names to reflect general use
417
def _process_line_labels(label, ntraces=1, ninputs=0, noutputs=0):
    """Convert a label specification into an array of line labels.

    Parameters
    ----------
    label : str, array_like of str, or None
        Label(s) for the lines.  A single string is used for every trace.
    ntraces : int, optional
        Number of traces.
    ninputs, noutputs : int, optional
        If both are positive, the labels are returned as a 3D array of
        shape (ntraces, ninputs, noutputs); a 1D array with one label per
        trace is broadcast across the last two dimensions.  Otherwise the
        labels are returned as an array without reshaping.

    Returns
    -------
    line_labels : ndarray or None
        Array of labels (None if `label` is None).

    Raises
    ------
    ValueError
        If the labels cannot be converted to an array or do not have a
        shape compatible with (ntraces, ninputs, noutputs).

    """
    if label is None:
        return None

    if isinstance(label, str):
        label = [label] * ntraces          # single label for all traces

    # Convert to an ndarray, if not done already
    try:
        line_labels = np.asarray(label)
    except ValueError:
        raise ValueError("label must be a string or array_like") from None

    # Turn the data into a 3D array of appropriate shape
    # TODO: allow more sophisticated broadcasting (and error checking)
    if ninputs > 0 and noutputs > 0:
        shape = (ntraces, ninputs, noutputs)
        try:
            if line_labels.ndim == 1 and line_labels.size == ntraces:
                # One label per trace: share it across input/output pairs
                line_labels = np.broadcast_to(
                    line_labels.reshape(ntraces, 1, 1), shape)
            else:
                line_labels = line_labels.reshape(shape)
        except ValueError:
            # A 0-d array (non-string scalar label) counts as one label;
            # indexing shape[0] on it would raise IndexError instead
            nlabels = line_labels.shape[0] if line_labels.ndim > 0 else 1
            if nlabels != ntraces:
                raise ValueError(
                    "number of labels must match number of traces") from None
            raise ValueError(
                "labels must be given for each input/output pair") from None

    return line_labels
447

448

449
# Get labels for all lines in an axes
450
def _get_line_labels(ax, use_color=True):
9✔
451
    labels_colors, lines = [], []
9✔
452
    last_color, counter = None, 0       # label unknown systems
9✔
453
    for i, line in enumerate(ax.get_lines()):
9✔
454
        label = line.get_label()
9✔
455
        color = line.get_color()
9✔
456
        if use_color and label.startswith("Unknown"):
9✔
457
            label = f"Unknown-{counter}"
×
458
            if last_color != color:
×
459
                counter += 1
×
460
            last_color = color
×
461
        elif label[0] == '_':
9✔
462
            continue
9✔
463

464
        if (label, color) not in labels_colors:
9✔
465
            lines.append(line)
9✔
466
            labels_colors.append((label, color))
9✔
467

468
    return lines, [label for label, color in labels_colors]
9✔
469

470

471
def _process_legend_keywords(
        kwargs, shape=None, default_loc='center right'):
    """Process legend-related keywords for a plotting function.

    The keys 'legend_loc', 'legend_map' and 'show_legend' are removed from
    `kwargs` (which is modified in place).

    Parameters
    ----------
    kwargs : dict
        Keyword arguments passed to the plotting function.
    shape : tuple of int, optional
        Shape of the axes array.  If None, 'legend_map' is not allowed.
    default_loc : str or int, optional
        Legend location used when 'legend_loc' is not given.

    Returns
    -------
    legend_loc : str or int
        Legend location.
    legend_map : 2D ndarray or None
        Per-axes legend locations (same shape as the axes array).
    show_legend : bool or None
        True/False to force showing/hiding the legend, None if unspecified.

    Raises
    ------
    TypeError
        If 'legend_map' is given but `shape` is None.
    ValueError
        If 'legend_loc' is not a string or int, or 'legend_map' does not
        match `shape`.

    """
    legend_loc = kwargs.pop('legend_loc', None)
    if shape is None and 'legend_map' in kwargs:
        raise TypeError("unexpected keyword argument 'legend_map'")
    legend_map = kwargs.pop('legend_map', None)
    show_legend = kwargs.pop('show_legend', None)

    # If legend_loc or legend_map is False, never show the legend; if either
    # was given otherwise, always show the legend
    if legend_loc is False or legend_map is False:
        if show_legend is True:
            warnings.warn(
                "show_legend ignored; legend_loc or legend_map was given")
        show_legend = False
        legend_loc = legend_map = None
    elif legend_loc is not None or legend_map is not None:
        if show_legend is False:
            warnings.warn(
                "show_legend ignored; legend_loc or legend_map was given")
        show_legend = True

    if legend_loc is None:
        legend_loc = default_loc
    elif not isinstance(legend_loc, (int, str)):
        raise ValueError("legend_loc must be string or int")

    # Make sure the legend map is the right size (accept list shapes too)
    if legend_map is not None:
        legend_map = np.atleast_2d(legend_map)
        if legend_map.shape != tuple(shape):
            raise ValueError("legend_map shape must match axes shape")

    return legend_loc, legend_map, show_legend
505

506

507
# Utility function to make legend labels
508
def _make_legend_labels(labels, ignore_common=False):
    """Strip the fields common to all labels to make short legend labels.

    Labels are treated as comma-separated fields (e.g. "sys, input,
    output").  A common prefix is removed only up to and including the
    last ", " separator it contains, and a common suffix is removed only
    starting at a comma, so that individual fields are never split.

    Parameters
    ----------
    labels : list of str
        Full labels for each line.
    ignore_common : bool, optional
        If True, no common prefix/suffix is removed.

    Returns
    -------
    labels : list of str
        Shortened labels.  If a single label is given, `labels` is
        returned unchanged.

    """
    if len(labels) == 1:
        return labels

    # Look for a common prefix, backing up to the last ", " separator.
    # A separator at position 0 is valid too (previously the whole common
    # prefix was stripped in that case, cutting into the next field).
    common_prefix = commonprefix(labels)
    last_sep = common_prefix.rfind(', ')
    if last_sep < 0 or ignore_common:
        prefix_len = 0
    else:
        prefix_len = last_sep + 2

    # Look for a common suffix, only chopping from a comma onward
    common_suffix = commonprefix(
        [label[::-1] for label in labels])[::-1]
    suffix_len = len(common_suffix)
    while suffix_len > 0 and common_suffix[-suffix_len] != ',':
        suffix_len -= 1

    # Strip the labels of common information
    if suffix_len > 0 and not ignore_common:
        return [label[prefix_len:-suffix_len] for label in labels]
    return [label[prefix_len:] for label in labels]
536

537

538
def _update_plot_title(
        title, fig=None, frame='axes', use_existing=True, **kwargs):
    """Set (or extend) the title of a figure.

    Parameters
    ----------
    title : str, None, or False
        Title text.  If None or False, the figure is left unchanged.
    fig : `matplotlib.figure.Figure`, optional
        Figure to update.  Defaults to the current figure.
    frame : str, optional
        'axes' (default) centers the title horizontally over the axes in
        the figure; 'figure' uses the default `suptitle` placement.
    use_existing : bool, optional
        If True (default) and the figure already has a title, the part of
        `title` that differs from the existing title is appended to it
        (e.g. "Bode plot for sys1" followed by "Bode plot for sys2" gives
        "Bode plot for sys1, sys2").  If False, the title is replaced.
    **kwargs : `matplotlib.figure.Figure.suptitle` keywords, optional
        The rc parameters are obtained via `config._get_param(...,
        pop=True)` (presumably consuming an 'rcParams' entry) and used as
        the rc context; the remaining keywords are passed to `suptitle`.

    Raises
    ------
    ValueError
        If `frame` is not 'axes' or 'figure'.

    """
    if title is False or title is None:
        return
    if fig is None:
        fig = plt.gcf()
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    if use_existing:
        # Get the current title, if it exists
        old_title = None if fig._suptitle is None \
            else fig._suptitle.get_text()

        if old_title is not None:
            # Find the common part of the titles
            common_prefix = commonprefix([old_title, title])

            # Back up to the last space so words are not split; if there
            # is no space, there is no usable common prefix
            last_space = common_prefix.rfind(' ')
            common_prefix = common_prefix[:last_space] \
                if last_space > 0 else ''
            common_len = len(common_prefix)

            # Add the new part of the title (usually the system name)
            if old_title[common_len:] != title[common_len:]:
                separator = ',' if common_len > 0 else ';'
                title = old_title + separator + title[common_len:]

    if frame == 'figure':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)

    elif frame == 'axes':
        with plt.rc_context(rcParams):
            fig.suptitle(title, **kwargs)           # Place title in center
            fig.tight_layout()                      # Lay out *this* figure
            xc, _ = _find_axes_center(fig, fig.get_axes())
            fig.suptitle(title, x=xc, **kwargs)     # Redraw title, centered

    else:
        raise ValueError(f"unknown frame '{frame}'")
578

579

580
def _find_axes_center(fig, axs):
    """Find the midpoint of the region spanned by a set of axes.

    The corners of each axes are mapped to figure coordinates (fractions of
    the figure size) and the center of their combined bounding box is
    returned.  The bounding box starts out "inverted" ([1, 0] in each
    direction), so an empty `axs` yields (0.5, 0.5).

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure containing the axes.
    axs : iterable of `matplotlib.axes.Axes`
        Axes whose combined extent is used.

    Returns
    -------
    (xc, yc) : tuple of float
        Center point in figure coordinates.

    """
    to_figure = fig.transFigure.inverted()
    left, bottom, right, top = 1, 1, 0, 0
    for ax in axs:
        lower_left = to_figure.transform(ax.transAxes.transform((0, 0)))
        upper_right = to_figure.transform(ax.transAxes.transform((1, 1)))
        left = min(lower_left[0], left)
        bottom = min(lower_left[1], bottom)
        right = max(upper_right[0], right)
        top = max(upper_right[1], top)

    return (np.sum([left, right]) / 2, np.sum([bottom, top]) / 2)
596

597

598
# Internal function to add arrows to a curve
599
def _add_arrows_to_line2D(
        axes, line, arrow_locs=(0.2, 0.4, 0.6, 0.8),
        arrowstyle='-|>', arrowsize=1, dir=1):
    """Add arrows to a matplotlib.lines.Line2D at selected locations.

    Parameters
    ----------
    axes : `matplotlib.axes.Axes`
        Axes containing the line (as returned by axes command or gca).
    line : `matplotlib.lines.Line2D`
        Line to decorate (as returned by plot command).
    arrow_locs : sequence of float, optional
        Locations at which to insert arrows, as fractions of the total arc
        length of the line.  If the line is short relative to the diagonal
        of the axes, a single arrow (or none) is drawn instead.
    arrowstyle : str, optional
        Style of the arrow (passed to `FancyArrowPatch`).
    arrowsize : float, optional
        Size of the arrow.  NOTE(review): not currently used by the
        implementation.
    dir : int, optional
        Arrow direction: 1 (default) points toward increasing index of the
        line data, -1 points toward decreasing index.

    Returns
    -------
    arrows : list of `matplotlib.patches.FancyArrowPatch`
        Arrows added to the axes (empty if the line has fewer than two
        points).

    Raises
    ------
    ValueError
        If `line` is not a Line2D object.
    NotImplementedError
        If the line has multiple colors or widths.

    Notes
    -----
    Based on https://stackoverflow.com/questions/26911898/

    """
    # Get the coordinates of the line, in plot coordinates
    if not isinstance(line, mpl.lines.Line2D):
        raise ValueError("expected a matplotlib.lines.Line2D object")
    x, y = line.get_xdata(), line.get_ydata()

    # A line with fewer than two points has no direction (and the arc
    # length array below would be empty), so there is nothing to add
    if len(x) < 2:
        return []

    # Determine the arrow properties
    arrow_kw = {"arrowstyle": arrowstyle}

    color = line.get_color()
    if isinstance(color, np.ndarray):
        raise NotImplementedError("multi-color lines not supported")
    arrow_kw['color'] = color

    linewidth = line.get_linewidth()
    if isinstance(linewidth, np.ndarray):
        raise NotImplementedError("multi-width lines not supported")
    arrow_kw['linewidth'] = linewidth

    # Figure out the size of the axes (length of diagonal)
    xlim, ylim = axes.get_xlim(), axes.get_ylim()
    ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
    diag = np.linalg.norm(ul - lr)

    # Compute the arc length along the curve (s[k] = length up to point k+1)
    s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))

    # Truncate the number of arrows if the curve is short
    # TODO: figure out a smarter way to do this
    frac = min(s[-1] / diag, 1)
    if len(arrow_locs) and frac < 0.05:
        arrow_locs = []         # too short; no arrows at all
    elif len(arrow_locs) and frac < 0.2:
        arrow_locs = [0.5]      # single arrow in the middle

    # Plot the arrows (and return list of patches)
    arrows = []
    for loc in arrow_locs:
        n = np.searchsorted(s, s[-1] * loc)

        if dir == 1 and n == 0:
            # Move the arrow forward by one if it is at start of a segment
            n = 1

        # Place the head of the arrow at the desired location
        arrow_head = [x[n], y[n]]
        arrow_tail = [x[n - dir], y[n - dir]]

        p = mpl.patches.FancyArrowPatch(
            arrow_tail, arrow_head, transform=axes.transData, lw=0,
            **arrow_kw)
        axes.add_patch(p)
        arrows.append(p)
    return arrows
678

679

680
def _get_color_offset(ax, color_cycle=None):
    """Get the color-cycle offset for the next line drawn in an axes.

    If the most recently drawn line in `ax` uses a color from the color
    cycle, the offset points to the color that follows it (wrapping around
    at the end of the cycle); otherwise the offset is 0.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axes containing already plotted lines.
    color_cycle : list of matplotlib color specs, optional
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
        color cycle.

    Returns
    -------
    color_offset : int
        Index into `color_cycle` of the color for the next line.
    color_cycle : list of matplotlib color specs
        Color cycle used to determine colors.

    """
    if color_cycle is None:
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    offset = 0
    if len(ax.lines) > 0:
        previous_color = ax.lines[-1].get_color()
        offset = color_cycle.index(previous_color) + 1 \
            if previous_color in color_cycle else 0

    return offset % len(color_cycle), color_cycle
712

713

714
def _get_color(
        colorspec, offset=None, fmt=None, ax=None, lines=None,
        color_cycle=None):
    """Get color to use for plotting line.

    This function returns the color to be used for the line to be drawn (or
    None if the default color cycle for the axes should be used).

    Parameters
    ----------
    colorspec : matplotlib color specification or dict
        User-specified color (or None).  If a dict containing a 'color'
        key is given, that entry is removed from the dict and returned.
    offset : int, optional
        Offset into the color cycle (for multi-trace plots).  Offsets past
        the end of the cycle wrap around.
    fmt : sequence of str, optional
        Format string(s) passed to plotting command.  If any of them
        contains a color code, None is returned so the format sets the
        color.
    ax : `matplotlib.axes.Axes`, optional
        Axes containing already plotted lines.
    lines : list of matplotlib.lines.Line2D, optional
        List of plotted lines.  If not given, use `ax.lines`.
    color_cycle : list of matplotlib color specs, optional
        Colors to use in plotting lines.  Defaults to matplotlib rcParams
        color cycle.

    Returns
    -------
    color : matplotlib color spec
        Color to use for this line (or None for matplotlib default).

    """
    # See if the color was explicitly specified by the user.  Note that the
    # fmt test must use any(): a non-empty list is always truthy, which
    # previously made every format string (e.g. '--') override `colorspec`.
    if isinstance(colorspec, dict):
        if 'color' in colorspec:
            return colorspec.pop('color')
    elif fmt is not None and any(
            isinstance(arg, str) and any(c in arg for c in "bgrcmykw#")
            for arg in fmt):
        return None             # *fmt will set the color
    elif colorspec is not None:
        return colorspec

    # Figure out what color cycle to use, if not given by caller
    if color_cycle is None:
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # Find the lines that we should pay attention to
    if lines is None and ax is not None:
        lines = ax.lines

    if offset is not None:
        return color_cycle[offset % len(color_cycle)]
    elif lines is not None:
        # Continue from the color of the last line (if it is in the cycle)
        color_offset = 0
        if len(lines) > 0:
            last_color = lines[-1].get_color()
            if last_color in color_cycle:
                color_offset = color_cycle.index(last_color) + 1
        return color_cycle[color_offset % len(color_cycle)]
    else:
        return None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc