• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13229932292

09 Feb 2025 09:52PM UTC coverage: 93.981% (-0.8%) from 94.752%
13229932292

Pull #1112

github

web-flow
Merge 81d9c1e2e into cf77f990b
Pull Request #1112: Use matplotlibs streamplot function for phase_plane_plot

9649 of 10267 relevant lines covered (93.98%)

4.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.81
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Initial author: Richard M. Murray
4
# Creation date: 24 July 2011, converted from MATLAB version (2002);
5
# based on an original version by Kristi Morgansen
6

7
"""Generate 2D phase portraits.
8

9
This module contains functions for generating 2D phase plots. The base
10
function for creating phase plane portraits is `~control.phase_plane_plot`,
11
which generates a phase plane portrait for a 2 state I/O system (with no
12
inputs). Utility functions are available to customize the individual
13
elements of a phase plane portrait.
14

15
The docstring examples assume the following import commands::
16

17
  >>> import numpy as np
18
  >>> import control as ct
19
  >>> import control.phaseplot as pp
20

21
"""
22

23
import math
5✔
24
import warnings
5✔
25

26
import matplotlib as mpl
5✔
27
import matplotlib.pyplot as plt
5✔
28
import numpy as np
5✔
29
from scipy.integrate import odeint
5✔
30

31
from . import config
5✔
32
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
5✔
33
    _process_ax_keyword, _update_plot_title
34
from .exception import ControlArgument
5✔
35
from .nlsys import NonlinearIOSystem, find_operating_point, \
5✔
36
    input_output_response
37

38
__all__ = [
    'phase_plane_plot',
    'phase_plot',
    'box_grid',
]

# Default values for module parameter variables.  Each entry can be
# overridden through config.defaults using the same key.
_phaseplot_defaults = {
    # Number of arrows drawn around each curve
    'phaseplot.arrows': 2,
    # Size of those arrows, in pixels
    'phaseplot.arrow_size': 8,
    # Arrow style (matplotlib patch) for the arrows
    'phaseplot.arrow_style': None,
    # Initial radius used when generating separatrices
    'phaseplot.separatrices_radius': 0.1,
}
47

48
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=None, plot_vectorfield=None, plot_streamplot=None,
        plot_equilpoints=True, plot_separatrices=True, ax=None,
        suppress_warnings=False, title=None, **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.
    If none of `plot_streamlines`, `plot_vectorfield`, or `plot_streamplot`
    are set, then `plot_streamplot` is used by default.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  Defaults to
        [-1, 1, -1, 1].
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.  If a gridtype
        other than 'meshgrid' is given and `plot_streamlines` is not set,
        a warning is issued and streamlines are plotted, because
        streamplots only support 'meshgrid'.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.  If only a streamplot is drawn, the default is
        [25, 25].
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use ``plot_<element>`` =
        {'color': c} to set the color in one element of the phase
        plot (equilpoints, separatrices, streamlines, etc)).
    ax : `matplotlib.axes.Axes`, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : `ControlPlot` object
        Object containing the data that were plotted.  See `ControlPlot`
        for more detailed information.
    cplt.lines : array of list of `matplotlib.lines.Line2D`
        Array of list of `matplotlib.artist.Artist` objects:

            - lines[0] = list of Line2D objects (streamlines, separatrices).
            - lines[1] = Quiver object (vector field arrows).
            - lines[2] = list of Line2D objects (equilibrium points).
            - lines[3] = StreamplotSet object (lines with arrows).

    cplt.axes : 2D array of `matplotlib.axes.Axes`
        Axes for each subplot.
    cplt.figure : `matplotlib.figure.Figure`
        Figure containing the plot.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given.
    ValueError
        If `plot_streamplot` is set and `gridtype` is not 'meshgrid'.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
    plot_streamlines : bool or dict, optional
        If True, then plot streamlines based on the pointdata and gridtype.
        If set to a dict, pass on the key-value pairs in the dict as
        keywords to `streamlines`.
    plot_vectorfield : bool or dict, optional
        If True, then plot the vector field based on the pointdata and
        gridtype.  If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.vectorfield`.
    plot_streamplot : bool or dict, optional
        If True, then use :func:`matplotlib.axes.Axes.streamplot` function
        to plot the streamlines.  If set to a dict, pass on the key-value
        pairs in the dict as keywords to :func:`~control.phaseplot.streamplot`.
    plot_equilpoints : bool or dict, optional
        If True (default) then plot equilibrium points based in the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If True (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to `phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    """
    # Check for legacy usage of plot_streamlines
    streamline_keywords = [
        'arrows', 'arrow_size', 'arrow_style', 'dir']
    if plot_streamlines is None:
        if any(kw in kwargs for kw in streamline_keywords):
            warnings.warn(
                "detected streamline keywords; use plot_streamlines to set",
                FutureWarning)
            plot_streamlines = True
        if gridtype not in [None, 'meshgrid']:
            # Streamplot needs a regular mesh; use streamlines instead
            warnings.warn(
                "streamplots only support gridtype='meshgrid'; "
                "falling back to streamlines")
            plot_streamlines = True

    # Default to a streamplot if no flow element was requested
    if (
        plot_streamlines is None
        and plot_vectorfield is None
        and plot_streamplot is None
    ):
        plot_streamplot = True

    # A streamplot on its own uses a finer default grid
    if plot_streamplot and not plot_streamlines and not plot_vectorfield:
        gridspec = gridspec or [25, 25]

    # Process arguments
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Utility function to combine zorders; None means "not yet set", so a
    # legitimate zorder of 0 is not mistaken for a missing value
    def _max_zorder(current, new):
        if new is None:
            return current
        return new if current is None else max(current, new)

    # Create list for storing outputs
    out = np.array([[], None, None, None], dtype=object)

    # The maximum zorder of streamlines, vectorfield or streamplot
    flow_zorder = None

    # Plot out the main elements
    if plot_streamlines:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False,
            suppress_warnings=suppress_warnings, **kwargs_local)

        # Streamlines may produce no lines at all; max() of an empty
        # sequence would raise, so only update the zorder when lines exist
        if out[0]:
            flow_zorder = _max_zorder(
                flow_zorder, max(elem.get_zorder() for elem in out[0]))

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        flow_zorder = _max_zorder(flow_zorder, out[1].get_zorder())

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_streamplot:
        if gridtype not in [None, 'meshgrid']:
            raise ValueError("gridtype must be 'meshgrid' when using streamplot")

        kwargs_local = _create_kwargs(
            kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
        out[3] = streamplot(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        flow_zorder = _max_zorder(
            flow_zorder,
            max(out[3].lines.get_zorder(), out[3].arrows.get_zorder()))

        # Get rid of keyword arguments handled by streamplot
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    # Draw separatrices on top of the flow elements
    sep_zorder = flow_zorder + 1 if flow_zorder is not None else None

    if plot_separatrices:
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax)
        kwargs_local['zorder'] = kwargs_local.get('zorder', sep_zorder)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Keep the flow-based zorder if no lines were drawn at all
        if out[0]:
            sep_zorder = max(elem.get_zorder() for elem in out[0])

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    # Draw equilibrium points on top of everything else
    equil_zorder = sep_zorder + 1 if sep_zorder is not None else None

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        kwargs_local['zorder'] = kwargs_local.get('zorder', equil_zorder)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError("unrecognized keywords: " + str(initial_kwargs))

    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
5✔
305

306

307
def vectorfield(
        sys, pointdata, gridspec=None, zorder=None, ax=None,
        suppress_warnings=False, _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.  The vector field is evaluated on a 'meshgrid' of
    points generated from `pointdata` and `gridspec`.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate points.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    zorder : float, optional
        Set the zorder for the vector field.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.quiver`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (and `_check_kwargs`
        is True).

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plane functions;
        not used by this function.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: " + str(kwargs))

    # Generate phase plane (quiver) data: each row is [x1, x2, dx1, dx2],
    # with the derivative evaluated at t = 0 and zero input
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, np.zeros(sys.ninputs))

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color, zorder=zorder)

    return out
5✔
401

402

403
def streamplot(
        sys, pointdata, gridspec=None, zorder=None, ax=None, vary_color=False,
        vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot streamlines in the phase plane.

    This function plots the streamlines for a two-dimensional state
    space system using the `matplotlib.axes.Axes.streamplot` function.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot.
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate points.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    vary_color : bool, optional
        If set to True, vary the color of the streamlines based on the
        magnitude of the vector field.
    vary_linewidth : bool, optional
        If set to True, vary the linewidth of the streamlines based on the
        magnitude of the vector field (between 0.25 and 2 times the default
        line width).
    cmap : str or Colormap, optional
        Colormap to use for varying the color of the streamlines.
    norm : `matplotlib.colors.Normalize`, optional
        An instance of Normalize to use for scaling the colormap and
        linewidths.  If not given, a new `Normalize` is created.
    zorder : float, optional
        Set the zorder for the streamlines.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.streamplot`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : StreamplotSet

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (and `_check_kwargs`
        is True).

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plane functions;
        not used by this function.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the streamplot field.  The
    # grid shape is reversed so that rows follow y and columns follow x,
    # which is the layout `Axes.streamplot` uses for 2D inputs.
    points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')
    grid_arr_shape = gridspec[::-1]
    xs = points[:, 0].reshape(grid_arr_shape)
    ys = points[:, 1].reshape(grid_arr_shape)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: " + str(kwargs))

    # Evaluate the vector field at each grid point (t = 0, zero input)
    sys._update_params(params)
    us_flat, vs_flat = np.transpose(
        [sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
    us = us_flat.reshape(grid_arr_shape)
    vs = vs_flat.reshape(grid_arr_shape)

    # Magnitude of the vector field, normalized for colors and linewidths
    magnitudes = np.linalg.norm([us, vs], axis=0)
    if norm is None:
        norm = mpl.colors.Normalize()
    normalized = norm(magnitudes)
    cmap = plt.get_cmap(cmap)

    with plt.rc_context(rcParams):
        # Scale linewidths between 0.25x and 2x the default width
        default_lw = plt.rcParams['lines.linewidth']
        min_lw, max_lw = 0.25 * default_lw, 2 * default_lw
        if vary_linewidth:
            linewidths = normalized * (max_lw - min_lw) + min_lw
        else:
            linewidths = None
        if vary_color:
            color = magnitudes

        out = ax.streamplot(xs, ys, us, vs, color=color, linewidth=linewidths,
                            cmap=cmap, norm=norm, zorder=zorder)

    return out
5✔
505

506
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        zorder=None, ax=None, _check_kwargs=True, suppress_warnings=False,
        **kwargs):
    """Plot stream lines in the phase plane.

    This function plots stream lines for a two-dimensional state space
    system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
        If not specified, 'both' is used when `gridtype` is 'meshgrid'
        and 'forward' otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    zorder : float, optional
        Set the zorder for the streamlines.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.plot`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects
        One line for each initial condition whose trajectory contains more
        than one point.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (and `_check_kwargs`
        is True).

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Parse the arrows keyword
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Determine the points on which to generate the streamlines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits; xlim/ylim are also passed to the trajectories
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: " + str(kwargs))

    # Create reverse time system (negated dynamics), if needed
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)
    else:
        revsys = None

    # Generate phase plane (streamline) data
    out = []
    for i, X0 in enumerate(points):
        # Create the trajectory for this point
        timepts = _make_timepts(timedata, i)
        traj = _create_trajectory(
            sys, revsys, timepts, X0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Plot the trajectory (if there is one)
        if traj.shape[1] > 1:
            with plt.rc_context(rcParams):
                out += ax.plot(traj[0], traj[1], color=color, zorder=zorder)

                # Add arrows to the lines at specified intervals
                _add_arrows_to_line2D(
                    ax, out[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return out
5✔
638

639

640
def equilpoints(
        sys, pointdata, gridspec=None, color='k', zorder=None, ax=None,
        _check_kwargs=True, **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical
    system.  Equilibria are located by searching from each point of a mesh
    of initial guesses generated from `pointdata`; an equilibrium that is
    found from several starting points is plotted only once.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        of points spanning the region of interest.
    gridspec : list, optional
        Size of the mesh in the x and y directions used to generate the
        initial guesses for the equilibrium point search.  Default is
        [5, 5].
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str, optional
        Plot the equilibrium points in the given color.  Default is 'k'.
    zorder : float, optional
        Set the zorder for the equilibrium points.  If not specified, it
        will be automatically chosen by `matplotlib.axes.Axes.plot`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (called for its side effect on `ax`)
    _set_axis_limits(ax, pointdata)

    # Determine the initial guesses used to search for equilibrium points
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points (one marker-only line per point)
    out = []
    with plt.rc_context(rcParams):
        for xeq in equilpts:
            out += ax.plot(
                xeq[0], xeq[1], marker='o', color=color, zorder=zorder)
    return out
725

726

727
def separatrices(
5✔
728
        sys, pointdata, timedata=None, gridspec=None, zorder=None, ax=None,
729
        _check_kwargs=True, suppress_warnings=False, **kwargs):
730
    """Plot separatrices in the phase plane.
731

732
    This function plots separatrices for a two-dimensional state space
733
    system.
734

735
    Parameters
736
    ----------
737
    sys : `NonlinearIOSystem` or callable(t, x, ...)
738
        I/O system or function used to generate phase plane data. If a
739
        function is given, the remaining arguments are drawn from the
740
        `params` keyword.
741
    pointdata : list or 2D array
742
        List of the form [xmin, xmax, ymin, ymax] describing the
743
        boundaries of the phase plot or an array of shape (N, 2)
744
        giving points of at which to plot the vector field.
745
    timedata : int or list of int
746
        Time to simulate each streamline.  If a list is given, a different
747
        time can be used for each initial condition in `pointdata`.
748
    gridtype : str, optional
749
        The type of grid to use for generating initial conditions:
750
        'meshgrid' (default) generates a mesh of initial conditions within
751
        the specified boundaries, 'boxgrid' generates initial conditions
752
        along the edges of the boundary, 'circlegrid' generates a circle of
753
        initial conditions around each point in point data.
754
    gridspec : list, optional
755
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
756
        size of the grid in the x and y axes on which to generate points.
757
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
758
        specifying the radius and number of points around each point in the
759
        `pointdata` array.
760
    params : dict or list, optional
761
        Parameters to pass to system. For an I/O system, `params` should be
762
        a dict of parameters and values. For a callable, `params` should be
763
        dict with key 'args' and value given by a tuple (passed to callable).
764
    color : matplotlib color spec, optional
765
        Plot the separatrices in the given color.  If a single color
766
        specification is given, this is used for both stable and unstable
767
        separatrices.  If a tuple is given, the first element is used as
768
        the color specification for stable separatrices and the second
769
        element for unstable separatrices.
770
    zorder : float, optional
771
        Set the zorder for the separatrices.  In not specified, it will be
772
        automatically chosen by `matplotlib.axes.Axes.plot`.
773
    ax : `matplotlib.axes.Axes`, optional
774
        Use the given axes for the plot, otherwise use the current axes.
775

776
    Returns
777
    -------
778
    out : list of Line2D objects
779

780
    Other Parameters
781
    ----------------
782
    rcParams : dict
783
        Override the default parameters used for generating plots.
784
        Default is set by `config.defaults['ctrlplot.rcParams']`.
785
    suppress_warnings : bool, optional
786
        If set to True, suppress warning messages in generating trajectories.
787

788
    Notes
789
    -----
790
    The value of `config.defaults['separatrices_radius']` is used to set the
791
    offset from the equilibrium point to the starting point of the separatix
792
    traces, in the direction of the eigenvectors evaluated at that
793
    equilibrium point.
794

795
    """
796
    # Process keywords
797
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
5✔
798

799
    # Get system parameters
800
    params = kwargs.pop('params', None)
5✔
801

802
    # Create system from callable, if needed
803
    sys = _create_system(sys, params)
5✔
804

805
    # Parse the arrows keyword
806
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
5✔
807

808
    # Determine the initial states to use in searching for equilibrium points
809
    gridspec = [5, 5] if gridspec is None else gridspec
5✔
810
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
5✔
811

812
    # Find the equilibrium points
813
    equilpts = _find_equilpts(sys, points, params=params)
5✔
814
    radius = config._get_param('phaseplot', 'separatrices_radius')
5✔
815

816
    # Create axis if needed
817
    if ax is None:
5✔
818
        ax = plt.gca()
5✔
819

820
    # Set the axis limits
821
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
5✔
822

823
    # Figure out the color to use for stable, unstable subspaces
824
    color = _get_color(kwargs)
5✔
825
    match color:
5✔
826
        case None:
5✔
827
            stable_color = 'r'
5✔
828
            unstable_color = 'b'
5✔
829
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
5✔
830
            pass
5✔
831
        case single_color:
5✔
832
            stable_color = unstable_color = single_color
5✔
833

834
    # Make sure all keyword arguments were processed
835
    if _check_kwargs and kwargs:
5✔
836
        raise TypeError("unrecognized keywords: ", str(kwargs))
5✔
837

838
    # Create a "reverse time" system to use for simulation
839
    revsys = NonlinearIOSystem(
5✔
840
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
841
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
842
        outputs=sys.noutputs, params=sys.params)
843

844
    # Plot separatrices by flowing backwards in time along eigenspaces
845
    out = []
5✔
846
    for i, xeq in enumerate(equilpts):
5✔
847
        # Figure out the linearization and eigenvectors
848
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
5✔
849

850
        # See if we have real eigenvalues (=> evecs are meaningful)
851
        if evals[0].imag > 0:
5✔
852
            continue
5✔
853

854
        # Create default list of time points
855
        if timedata is not None:
5✔
856
            timepts = _make_timepts(timedata, i)
5✔
857

858
        # Generate the traces
859
        for j, dir in enumerate(evecs.T):
5✔
860
            # Figure out time vector if not yet computed
861
            if timedata is None:
5✔
862
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
5✔
863
                timepts = np.linspace(0, timescale)
5✔
864

865
            # Run the trajectory starting in eigenvector directions
866
            for eps in [-radius, radius]:
5✔
867
                x0 = xeq + dir * eps
5✔
868
                if evals[j].real < 0:
5✔
869
                    traj = _create_trajectory(
5✔
870
                        sys, revsys, timepts, x0, params, 'reverse',
871
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
872
                        suppress_warnings=suppress_warnings)
873
                    color = stable_color
5✔
874
                    linestyle = '--'
5✔
875
                elif evals[j].real > 0:
5✔
876
                    traj = _create_trajectory(
5✔
877
                        sys, revsys, timepts, x0, params, 'forward',
878
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
879
                        suppress_warnings=suppress_warnings)
880
                    color = unstable_color
5✔
881
                    linestyle = '-'
5✔
882

883
                # Plot the trajectory (if there is one)
884
                if traj.shape[1] > 1:
5✔
885
                    with plt.rc_context(rcParams):
5✔
886
                        out += ax.plot(
5✔
887
                            traj[0], traj[1], color=color, linestyle=linestyle, zorder=zorder)
888

889
                    # Add arrows to the lines at specified intervals
890
                    with plt.rc_context(rcParams):
5✔
891
                        _add_arrows_to_line2D(
5✔
892
                            ax, out[-1], arrow_pos, arrowstyle=arrow_style,
893
                            dir=1)
894
    return out
5✔
895

896

897
#
898
# User accessible utility functions
899
#
900

901
# Utility function to generate boxgrid (in the form needed here)
902
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    points = boxgrid(xvals, yvals) generates a list of points that lie on
    the perimeter of the rectangle spanned by the x and y values.  The
    points are listed edge by edge: lower edge, right edge, upper edge,
    then left edge, starting at (xvals[0], yvals[0]).  Each corner appears
    exactly once.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge (p may be 0).

    """
    # Each edge omits its final point, which is the first point of the
    # next edge, so that corners are not duplicated
    edges = (
        [(x, yvals[0]) for x in xvals[:-1]] +           # lower edge
        [(xvals[-1], y) for y in yvals[:-1]] +          # right edge
        [(x, yvals[-1]) for x in xvals[:0:-1]] +        # upper edge
        [(xvals[0], y) for y in yvals[:0:-1]]           # left edge
    )

    # reshape keeps the documented (p, 2) shape even when p == 0
    return np.array(edges).reshape(-1, 2)
927

928

929
# Utility function to generate meshgrid (in the form needed here)
930
# TODO: add examples of using grid functions directly
931
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    points = meshgrid(xvals, yvals) generates a list of points that
    corresponds to a grid given by the cross product of the x and y values.
    Points are ordered with the x value varying fastest.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Coordinates of the grid lines along the x and y axes.

    Returns
    -------
    grid : 2D array
        Float array of points with shape (n * m, 2) defining the mesh.

    """
    xx, yy = np.meshgrid(xvals, yvals)

    # Pair up the flattened coordinates column-wise; always return floats
    return np.column_stack((xx.ravel(), yy.ravel())).astype(float)
955

956

957
# Utility function to generate circular grid
958
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    points = circlegrid(centers, radius, num) generates a list of points
    that form a circle around a list of centers.  Points on each circle
    start at angle 0 and proceed counter-clockwise at equal angular spacing.

    Parameters
    ----------
    centers : 2D array_like
        Array of points with shape (p, 2) defining centers of the circles.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid : 2D array
        Array of points with shape (p * num, 2) defining the circles.

    """
    centers = np.atleast_2d(np.array(centers))

    # The ring of offsets is the same for every center, so build it once
    angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
    ring = np.array([
        [radius * math.cos(angle), radius * math.sin(angle)]
        for angle in angles])

    grid = np.zeros((centers.shape[0] * num, 2))
    for k, center in enumerate(centers):
        start = k * num
        grid[start:start + num, :] = center + ring
    return grid
986

987
#
988
# Internal utility functions
989
#
990

991
# Create a system from a callable
992
def _create_system(sys, params):
    """Return `sys` as a planar `NonlinearIOSystem`.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        Either a system (returned unchanged after checking it has two
        states) or a function returning dx/dt, which is wrapped in a
        two-state system with no inputs or outputs.
    params : dict or None
        For a callable, a dict whose 'args' entry is a tuple of extra
        positional arguments passed to the callable.  An empty tuple is
        allowed.

    Returns
    -------
    `NonlinearIOSystem`

    Raises
    ------
    ValueError
        If a system does not have exactly two states, or if `params` is
        given for a callable but is not a dict with an 'args' entry.

    """
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates != 2:
            raise ValueError("system must be planar")
        return sys

    # Make sure that if params is present, it is a dict with an 'args' key
    # (previously an empty 'args' tuple was wrongly rejected, and a
    # non-dict raised AttributeError instead of a ValueError)
    if params and (
            not isinstance(params, dict) or params.get('args') is None):
        raise ValueError("params must be dict with key 'args'")

    # Wrap the callable: extra positional arguments come from params['args']
    # at simulation time
    def _update(t, x, u, params):
        return sys(t, x, *params.get('args', ()))

    def _output(t, x, u, params):
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
1006

1007
# Set axis limits for the plot
1008
def _set_axis_limits(ax, pointdata):
5✔
1009
    # Get the current axis limits
1010
    if ax.lines:
5✔
1011
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
5✔
1012
    else:
1013
        # Nothing on the plot => always use new limits
1014
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]
5✔
1015

1016
    # Short utility function for updating axis limits
1017
    def _update_limits(cur, new):
5✔
1018
        return [min(cur[0], np.min(new)), max(cur[1], np.max(new))]
5✔
1019

1020
    # If we were passed a box, use that to update the limits
1021
    if isinstance(pointdata, list) and len(pointdata) == 4:
5✔
1022
        xlim = _update_limits(xlim, [pointdata[0], pointdata[1]])
5✔
1023
        ylim = _update_limits(ylim, [pointdata[2], pointdata[3]])
5✔
1024

1025
    elif isinstance(pointdata, np.ndarray):
5✔
1026
        pointdata = np.atleast_2d(pointdata)
5✔
1027
        xlim = _update_limits(
5✔
1028
            xlim, [np.min(pointdata[:, 0]), np.max(pointdata[:, 0])])
1029
        ylim = _update_limits(
5✔
1030
            ylim, [np.min(pointdata[:, 1]), np.max(pointdata[:, 1])])
1031

1032
    # Keep track of the largest dimension on the plot
1033
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])
5✔
1034

1035
    # Set the new limits
1036
    ax.autoscale(enable=True, axis='x', tight=True)
5✔
1037
    ax.autoscale(enable=True, axis='y', tight=True)
5✔
1038
    ax.set_xlim(xlim)
5✔
1039
    ax.set_ylim(ylim)
5✔
1040

1041
    return xlim, ylim, maxlim
5✔
1042

1043

1044
# Find equilibrium points
1045
def _find_equilpts(sys, points, params=None):
5✔
1046
    equilpts = []
5✔
1047
    for i, x0 in enumerate(points):
5✔
1048
        # Look for an equilibrium point near this point
1049
        xeq, ueq = find_operating_point(sys, x0, 0, params=params)
5✔
1050

1051
        if xeq is None:
5✔
1052
            continue            # didn't find anything
5✔
1053

1054
        # See if we have already found this point
1055
        seen = False
5✔
1056
        for x in equilpts:
5✔
1057
            if np.allclose(np.array(x), xeq):
5✔
1058
                seen = True
5✔
1059
        if seen:
5✔
1060
            continue
5✔
1061

1062
        # Save a new point
1063
        equilpts += [xeq.tolist()]
5✔
1064

1065
    return equilpts
5✔
1066

1067

1068
def _make_points(pointdata, gridspec, gridtype):
5✔
1069
    # Check to see what type of data we got
1070
    if isinstance(pointdata, np.ndarray) and gridtype is None:
5✔
1071
        pointdata = np.atleast_2d(pointdata)
5✔
1072
        if pointdata.shape[1] == 2:
5✔
1073
            # Given a list of points => no action required
1074
            return pointdata, None
5✔
1075

1076
    # Utility function to parse (and check) input arguments
1077
    def _parse_args(defsize):
5✔
1078
        if gridspec is None:
5✔
1079
            return defsize
5✔
1080

1081
        elif not isinstance(gridspec, (list, tuple)) or \
5✔
1082
             len(gridspec) != len(defsize):
1083
            raise ValueError("invalid grid specification")
5✔
1084

1085
        return gridspec
5✔
1086

1087
    # Generate points based on grid type
1088
    match gridtype:
5✔
1089
        case 'boxgrid' | None:
5✔
1090
            gridspec = _parse_args([6, 4])
5✔
1091
            points = boxgrid(
5✔
1092
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1093
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1094

1095
        case 'meshgrid':
5✔
1096
            gridspec = _parse_args([9, 6])
5✔
1097
            points = meshgrid(
5✔
1098
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1099
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1100

1101
        case 'circlegrid':
5✔
1102
            gridspec = _parse_args((0.5, 10))
5✔
1103
            if isinstance(pointdata, np.ndarray):
5✔
1104
                # Create circles around each point
1105
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
5✔
1106
            else:
1107
                # Create circle around center of the plot
1108
                points = circlegrid(
5✔
1109
                    np.array(
1110
                        [(pointdata[0] + pointdata[1]) / 2,
1111
                         (pointdata[0] + pointdata[1]) / 2]),
1112
                    gridspec[0], gridspec[1])
1113

1114
        case _:
5✔
1115
            raise ValueError(f"unknown grid type '{gridtype}'")
5✔
1116

1117
    return points, gridspec
5✔
1118

1119

1120
def _parse_arrow_keywords(kwargs):
    """Extract arrow placement and style from plotting keywords.

    The 'arrows', 'arrow_size' and 'arrow_style' keywords are removed from
    `kwargs` (falling back to the `phaseplot` configuration defaults) so
    that callers' unused-keyword checks do not reject them.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments; modified in place.

    Returns
    -------
    arrow_pos : list or 1D array
        Fractional positions along a line at which to draw arrows.  An
        integer N spaces N arrows evenly, starting midway through the first
        of N equal segments; a list/tuple/array gives explicit positions
        (sorted); a falsy value gives no arrows.
    arrow_style : `matplotlib.patches.ArrowStyle`
        Arrow style; built from 'arrow_size' if no style is given.

    Raises
    ------
    ValueError
        If 'arrows' has an unsupported type.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # Parse the arrows keyword.  Arrays are tested first because the
    # truth value of a multi-element array is ambiguous.
    if isinstance(arrows, np.ndarray):
        arrow_pos = np.sort(np.atleast_1d(arrows))
    elif not arrows:
        arrow_pos = []
    elif isinstance(arrows, (int, np.integer)):
        N = int(arrows)
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    elif isinstance(arrows, (list, tuple)):
        arrow_pos = np.sort(np.atleast_1d(arrows))
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
1148

1149

1150
# TODO: move to ctrlplot?
1151
def _create_trajectory(
5✔
1152
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
1153
        gridtype=None, gridspec=None, xlim=None, ylim=None):
1154
    # Compute the forward trajectory
1155
    if dir == 'forward' or dir == 'both':
5✔
1156
        fwdresp = input_output_response(
5✔
1157
            sys, timepts, X0=X0, params=params, ignore_errors=True)
1158
        if not fwdresp.success and not suppress_warnings:
5✔
1159
            warnings.warn(f"{X0=}, {fwdresp.message}")
5✔
1160

1161
    # Compute the reverse trajectory
1162
    if dir == 'reverse' or dir == 'both':
5✔
1163
        revresp = input_output_response(
5✔
1164
            revsys, timepts, X0=X0, params=params, ignore_errors=True)
1165
        if not revresp.success and not suppress_warnings:
5✔
1166
            warnings.warn(f"{X0=}, {revresp.message}")
×
1167

1168
    # Create the trace to plot
1169
    if dir == 'forward':
5✔
1170
        traj = fwdresp.states
5✔
1171
    elif dir == 'reverse':
5✔
1172
        traj = revresp.states[:, ::-1]
5✔
1173
    elif dir == 'both':
5✔
1174
        traj = np.hstack([revresp.states[:, :1:-1], fwdresp.states])
5✔
1175

1176
    # Remove points outside the window (keep first point beyond boundary)
1177
    inrange = np.asarray(
5✔
1178
        (traj[0] >= xlim[0]) & (traj[0] <= xlim[1]) &
1179
        (traj[1] >= ylim[0]) & (traj[1] <= ylim[1]))
1180
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
5✔
1181
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range
5✔
1182

1183
    return traj[:, inrange]
5✔
1184

1185

1186
def _make_timepts(timepts, i):
5✔
1187
    if timepts is None:
5✔
1188
        return np.linspace(0, 1)
5✔
1189
    elif isinstance(timepts, (int, float)):
5✔
1190
        return np.linspace(0, timepts)
5✔
1191
    elif timepts.ndim == 2:
×
1192
        return timepts[i]
×
1193
    return timepts
×
1194

1195

1196
#
1197
# Legacy phase plot function
1198
#
1199
# Author: Richard Murray
1200
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1201
# a version by Kristi Morgansen
1202
#
1203
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):
    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use `phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by the `phase_plane_map` and
    `phase_plane_plot` functions.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - plot both
      phase_plot(func, X0=[...], lintime=N, ...) - stream lines with arrows

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Computes the time derivative of the state (compatible with
        `scipy.integrate.odeint`).  It should be a function of the form
        dx/dt = F(x, t, ...) (or F(t, x, ...) if `tfirst` is True) that
        accepts a state x of dimension 2 and returns a derivative dx/dt of
        dimension 2.
    X, Y : 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.
    scale : float, optional
        Scale size of arrows; default = 1.  If None, arrows are drawn
        without scaling and streamline arrows outside the current axis
        limits are suppressed.  If negative, dots are also drawn at the
        base of the streamline arrows.
    X0 : ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T : array_like or number, optional
        Simulation time for the streamlines.  If a single number, each
        streamline is simulated on 100 points from 0 to `T`.  Otherwise,
        `T` is used directly as the array of time points passed to
        `scipy.integrate.odeint` for every initial condition.  Default
        value = 50.
    lingrid : int, optional
        If given, draw `lingrid` arrows along each streamline at equally
        spaced indices of the simulation time vector.
    lintime : optional
        Not used by this implementation (accepted for backward
        compatibility).
    logtime : tuple (int, float), optional
        If a tuple (N, lambda) is given, draw N arrows along each
        streamline at time points spaced approximately 1/lambda apart.
    timepts : array_like, optional
        Draw arrows at the given list times [t1, t2, ...]
    parms : tuple, optional
        Deprecated alias for `params`.
    params : tuple, optional
        List of parameters to pass to vector field:
        ``odefun(x, t, *params)``.
    tfirst : bool, optional
        If True, call `odefun` with signature ``odefun(t, x, ...)``.
    verbose : bool, optional
        If True (default), print a message describing the arrow mode when
        `lingrid` or `logtime` is used.

    Raises
    ------
    ControlArgument
        If both `parms` and `params` are given.

    See Also
    --------
    box_grid

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            "keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument("duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if verbose:
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if verbose:
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # True when arrows are placed along the streamlines (vs. on a grid)
    arrowFlag = autoFlag or logtimeFlag or timeptsFlag

    if not arrowFlag and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i, j], x2[i, j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i, j], x2[i, j]], 0, *params))

        # Plot the quiver plot
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            # dx has only two components (indices 0 and 1)
            plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
        elif scale != 0:
            plt.quiver(x1, x2, dx[:, :, 0]*np.abs(scale),
                       dx[:, :, 1]*np.abs(scale), angles='xy')
            #! TODO: optimize parameters for arrows
            #! TODO: figure out arguments to make arrows show up correctly
            # xy = plt.quiver(...)
            # set(xy, 'LineWidth', PP_arrow_linewidth, 'Color', 'b')

        #! TODO: Tweak the shape of the plot
        # a=gca; set(a,'DataAspectRatio',[1,1,1])
        # set(a,'XLim',X(1:2)); set(a,'YLim',Y(1:2))
        plt.xlabel('x1'); plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    (nr, nc) = np.shape(X0)

    # Generate some empty matrices to keep arrow information
    # NOTE(review): entries skipped below (tailless arrows at t=0) are left
    # uninitialized
    x1 = np.empty((nr, Narrows))
    x2 = np.empty((nr, Narrows))
    dx = np.empty((nr, Narrows, 2))

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed
    TSPAN = T
    if isinstance(T, (int, float)):
        TSPAN = np.linspace(0, T, 100)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin = alim[0]; xmax = alim[1]
        ymin = alim[2]; ymax = alim[3]
    else:
        # Use the maximum extent of all trajectories
        xmin = np.min(X0[:, 0]); xmax = np.max(X0[:, 0])
        ymin = np.min(X0[:, 1]); ymax = np.max(X0[:, 1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = TSPAN

        plt.plot(state[:, 0], state[:, 1])
        #! TODO: add back in colors for stream lines
        # PP_stream_color(np.mod(i-1, len(PP_stream_color))+1))
        # set(h[i], 'LineWidth', PP_stream_linewidth)

        # Plot arrows if quiver parameters were 'auto'
        if arrowFlag:
            # Compute the locations of the arrows
            #! TODO: check this logic to make sure it works in python
            for j in range(Narrows):

                # Figure out starting index; headless arrows start at 0
                k = -1 if scale is None else 0

                # Figure out what time index to use for the next point
                if autoFlag:
                    # Use a linear scaling based on ODE time vector; the
                    # result must be an integer to index `state`
                    tind = int(np.floor((len(time)/Narrows) * (j-k))) + k
                elif logtimeFlag:
                    # Use an exponential time vector
                    # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                    tarr = _find(time < (j-k) / timefactor)
                    tind = tarr[-1] if len(tarr) else 0
                elif timeptsFlag:
                    # Use specified time points
                    # MATLAB: tind = find(time < Y[j], 1, 'last')
                    tarr = _find(time < timepts[j])
                    tind = tarr[-1] if len(tarr) else 0

                # For tailless arrows, skip the first point
                if tind == 0 and scale is None:
                    continue

                # Figure out the arrow at this point on the curve
                x1[i, j] = state[tind, 0]
                x2[i, j] = state[tind, 1]

                # Skip arrows outside of initial condition box
                if (scale is not None or
                     (x1[i, j] <= xmax and x1[i, j] >= xmin and
                      x2[i, j] <= ymax and x2[i, j] >= ymin)):
                    if tfirst:
                        v = odefun(0, [x1[i, j], x2[i, j]], *params)
                    else:
                        v = odefun([x1[i, j], x2[i, j]], 0, *params)
                    dx[i, j, 0] = v[0]; dx[i, j, 1] = v[1]
                else:
                    dx[i, j, 0] = 0; dx[i, j, 1] = 0

    # Set the plot shape before plotting arrows to avoid warping
    # a=gca
    # if (scale != None):
    #     set(a,'DataAspectRatio', [1,1,1])
    # if (xmin != xmax and ymin != ymax):
    #     plt.axis([xmin, xmax, ymin, ymax])
    # set(a, 'Box', 'on')

    # Plot arrows on the streamlines.  Only done when arrow locations were
    # computed above; otherwise x1, x2, dx hold uninitialized data.
    if arrowFlag and scale is None and Narrows > 0:
        # Use a tailless arrow
        #! TODO: figure out arguments to make arrows show up correctly
        plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
    elif arrowFlag and scale != 0 and Narrows > 0:
        plt.quiver(x1, x2, dx[:, :, 0]*abs(scale), dx[:, :, 1]*abs(scale),
                   angles='xy')
        #! TODO: figure out arguments to make arrows show up correctly
        # xy = plt.quiver(...)
        # set(xy, 'LineWidth', PP_arrow_linewidth)
        # set(xy, 'AutoScale', 'off')
        # set(xy, 'AutoScaleFactor', 0)

    # scale may be None, which cannot be compared with 0
    if arrowFlag and scale is not None and scale < 0:
        plt.plot(x1, x2, 'b.')        # add dots at base
        # bp = plt.plot(...)
        # set(bp, 'MarkerSize', PP_arrow_markersize)
1467

1468

1469
# Utility function for generating initial conditions around a box
1470
def box_grid(xlimp, ylimp):
    """Generate list of points on edge of box.

    .. deprecated:: 0.10.0
        Use `phaseplot.boxgrid` instead.

    list = box_grid([xmin xmax xnum], [ymin ymax ynum]) generates a
    list of points that correspond to a uniform grid on the edges of the
    box defined by the corners [xmin ymin] and [xmax ymax].

    """
    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    # Expand the [start, stop, num] specifications into coordinate vectors
    # and delegate to the replacement function
    xvals = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    yvals = np.linspace(ylimp[0], ylimp[1], ylimp[2])
    return boxgrid(xvals, yvals)
1490

1491

1492
# TODO: rename to something more useful (or remove??)
1493
def _find(condition):
5✔
1494
    """Returns indices where ravel(a) is true.
1495

1496
    Private implementation of deprecated `matplotlib.mlab.find`.
1497

1498
    """
1499
    return np.nonzero(np.ravel(condition))[0]
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc