• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13199436913

07 Feb 2025 12:03PM UTC coverage: 94.817% (+0.07%) from 94.752%
13199436913

Pull #1112

github

web-flow
Merge c806ed588 into f73e893e8
Pull Request #1112: Use matplotlibs streamplot function for phase_plane_plot

9732 of 10264 relevant lines covered (94.82%)

8.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.17
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Initial author: Richard M. Murray
4
# Creation date: 24 July 2011, converted from MATLAB version (2002);
5
# based on an original version by Kristi Morgansen
6

7
"""Generate 2D phase portraits.
8

9
This module contains functions for generating 2D phase plots. The base
10
function for creating phase plane portraits is `~control.phase_plane_plot`,
11
which generates a phase plane portrait for a 2 state I/O system (with no
12
inputs). Utility functions are available to customize the individual
13
elements of a phase plane portrait.
14

15
The docstring examples assume the following import commands::
16

17
  >>> import numpy as np
18
  >>> import control as ct
19
  >>> import control.phaseplot as pp
20

21
"""
22

23
import math
9✔
24
import warnings
9✔
25

26
import matplotlib as mpl
9✔
27
import matplotlib.pyplot as plt
9✔
28
import numpy as np
9✔
29
from scipy.integrate import odeint
9✔
30

31
from . import config
9✔
32
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
33
    _process_ax_keyword, _update_plot_title
34
from .exception import ControlArgument
9✔
35
from .nlsys import NonlinearIOSystem, find_operating_point, \
9✔
36
    input_output_response
37

38
__all__ = ['phase_plane_plot', 'phase_plot', 'box_grid']
9✔
39

40
# Default values for the module parameters; each key can be overridden
# through config.defaults (e.g. config.defaults['phaseplot.arrows']).
_phaseplot_defaults = {
    # Number of arrows drawn along each curve
    'phaseplot.arrows': 2,
    # Arrow size, in pixels
    'phaseplot.arrow_size': 8,
    # Arrow style (matplotlib patch specification, or None)
    'phaseplot.arrow_style': None,
    # Initial radius used when starting separatrices
    'phaseplot.separatrices_radius': 0.1,
}
47

48
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=None, plot_vectorfield=None, plot_streamplot=None,
        plot_equilpoints=True, plot_separatrices=True, ax=None,
        suppress_warnings=False, title=None, **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.
    If none of `plot_streamlines`, `plot_vectorfield`, or `plot_streamplot`
    are set, then `plot_streamplot` is used by default.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  Defaults to
        [-1, 1, -1, 1].
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.  When
        `plot_streamplot` is set, only 'meshgrid' (or None) is allowed.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.  If the streamplot is the only flow element
        drawn, `gridspec` defaults to [25, 25].
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use ``plot_<element>`` =
        {'color': c} to set the color in one element of the phase
        plot (equilpoints, separatrices, streamlines, etc)).
    ax : `matplotlib.axes.Axes`, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : `ControlPlot` object
        Object containing the data that were plotted.  See `ControlPlot`
        for more detailed information.
    cplt.lines : array of list of `matplotlib.lines.Line2D`
        Array of list of `matplotlib.artist.Artist` objects:

            - lines[0] = list of Line2D objects (streamlines, separatrices).
            - lines[1] = Quiver object (vector field arrows).
            - lines[2] = list of Line2D objects (equilibrium points).
            - lines[3] = StreamplotSet object (lines with arrows).

    cplt.axes : 2D array of `matplotlib.axes.Axes`
        Axes for each subplot.
    cplt.figure : `matplotlib.figure.Figure`
        Figure containing the plot.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
    plot_streamlines : bool or dict, optional
        If True, then plot streamlines based on the pointdata and gridtype.
        If set to a dict, pass on the key-value pairs in the dict as
        keywords to `streamlines`.
    plot_vectorfield : bool or dict, optional
        If True, then plot the vector field based on the pointdata and
        gridtype.  If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.vectorfield`.
    plot_streamplot : bool or dict, optional
        If True, then use :func:`matplotlib.axes.Axes.streamplot` function
        to plot the streamlines.  If set to a dict, pass on the key-value
        pairs in the dict as keywords to :func:`~control.phaseplot.streamplot`.
    plot_equilpoints : bool or dict, optional
        If True (default) then plot equilibrium points within the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If True (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to `phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    Raises
    ------
    ValueError
        If `plot_streamplot` is set and `gridtype` is not 'meshgrid'
        (or None).
    TypeError
        If unrecognized keyword arguments are given.

    """
    # If no flow element was requested explicitly, default to streamplot
    if (
        plot_streamlines is None
        and plot_vectorfield is None
        and plot_streamplot is None
    ):
        plot_streamplot = True

    # A streamplot drawn on its own uses a denser default grid
    if plot_streamplot and not plot_streamlines and not plot_vectorfield:
        gridspec = gridspec or [25, 25]

    # Validate the grid type before anything is drawn, so that an invalid
    # combination does not leave a partially drawn plot behind
    if plot_streamplot and gridtype not in [None, 'meshgrid']:
        raise ValueError("gridtype must be 'meshgrid' when using streamplot")

    # Process arguments
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments: global keywords are
    # overridden by the explicit extras, which in turn are overridden by
    # the element-specific dict (if one was given)
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Utility function to take the maximum of zorders, skipping None
    # (a zorder of 0 is a legitimate value and must not be discarded)
    def _max_zorder(*zorders):
        valid = [z for z in zorders if z is not None]
        return max(valid) if valid else None

    # Create list for storing outputs
    out = np.array([[], None, None, None], dtype=object)

    # The maximum zorder of streamlines, vectorfield or streamplot
    flow_zorder = None

    # Plot out the main elements
    if plot_streamlines:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False,
            suppress_warnings=suppress_warnings, **kwargs_local)

        # out[0] may be empty if no trajectory could be drawn
        flow_zorder = _max_zorder(
            flow_zorder, *(elem.get_zorder() for elem in out[0]))

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        flow_zorder = _max_zorder(flow_zorder, out[1].get_zorder())

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_streamplot:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
        out[3] = streamplot(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        flow_zorder = _max_zorder(
            flow_zorder, out[3].lines.get_zorder(),
            out[3].arrows.get_zorder())

        # Get rid of keyword arguments handled by streamplot
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    # Draw separatrices just above the flow elements (if any were drawn)
    sep_zorder = flow_zorder + 1 if flow_zorder is not None else None

    if plot_separatrices:
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax)
        kwargs_local['zorder'] = kwargs_local.get('zorder', sep_zorder)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # If neither streamlines nor separatrices produced any lines,
        # keep the previous value (max() of an empty sequence would raise)
        if out[0]:
            sep_zorder = max(elem.get_zorder() for elem in out[0])

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    # Draw equilibrium points above everything else
    equil_zorder = sep_zorder + 1 if sep_zorder is not None else None

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        kwargs_local['zorder'] = kwargs_local.get('zorder', equil_zorder)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError(f"unrecognized keywords: {initial_kwargs}")

    # Only decorate the figure if we created the axes ourselves
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
290

291

292
def vectorfield(
        sys, pointdata, gridspec=None, zorder=None, ax=None,
        suppress_warnings=False, _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.  The points at which the field is evaluated are always
    generated with the 'meshgrid' grid type.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate points.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    zorder : float, optional
        Set the zorder for the vector field.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.quiver`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions; this
        function does not simulate trajectories and does not use it.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (and `_check_kwargs`
        is True).

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Generate phase plane (quiver) data; each row holds the point (x, y)
    # followed by the state derivative evaluated there with zero input
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, np.zeros(sys.ninputs))

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color, zorder=zorder)

    return out
386

387

388
def streamplot(
        sys, pointdata, gridspec=None, zorder=None, ax=None, vary_color=False,
        vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot streamlines in the phase plane.

    This function plots the streamlines for a two-dimensional state
    space system using the `matplotlib.axes.Axes.streamplot` function.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot.
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate points.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the streamlines in the given color (ignored if `vary_color`
        is True).
    vary_color : bool, optional
        If set to True, vary the color of the streamlines based on the
        magnitude of the vector field.
    vary_linewidth : bool, optional
        If set to True, vary the linewidth of the streamlines based on the
        magnitude of the vector field.  Normalized magnitudes in [0, 1] map
        to widths between 0.25 and 2 times ``rcParams['lines.linewidth']``.
    cmap : str or Colormap, optional
        Colormap to use for varying the color of the streamlines.
    norm : `matplotlib.colors.Normalize`, optional
        An instance of Normalize to use for scaling the colormap and
        linewidths.  If not given, a default `matplotlib.colors.Normalize`
        is used, which scales to the range of the computed magnitudes.
    zorder : float, optional
        Set the zorder for the streamlines.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.streamplot`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : StreamplotSet

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions; this
        function does not simulate trajectories and does not use it.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (and `_check_kwargs`
        is True).

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the streamplot field and
    # reshape them into the 2D coordinate arrays expected by streamplot
    # (gridspec is [nx, ny], arrays are indexed [row=y, col=x])
    points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')
    grid_arr_shape = gridspec[::-1]
    xs = points[:, 0].reshape(grid_arr_shape)
    ys = points[:, 1].reshape(grid_arr_shape)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Evaluate the vector field (with zero input) at every grid point
    sys._update_params(params)
    us_flat, vs_flat = np.transpose(
        [sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
    us = us_flat.reshape(grid_arr_shape)
    vs = vs_flat.reshape(grid_arr_shape)

    # Magnitudes drive the optional color and linewidth variation
    magnitudes = np.linalg.norm([us, vs], axis=0)
    norm = norm or mpl.colors.Normalize()
    normalized = norm(magnitudes)
    cmap = plt.get_cmap(cmap)

    with plt.rc_context(rcParams):
        # Read the default line width inside the context so that any
        # rcParams override is respected
        default_lw = plt.rcParams['lines.linewidth']
        min_lw, max_lw = 0.25*default_lw, 2*default_lw
        linewidths = normalized * (max_lw - min_lw) + min_lw \
            if vary_linewidth else None
        color = magnitudes if vary_color else color

        out = ax.streamplot(xs, ys, us, vs, color=color, linewidth=linewidths,
                            cmap=cmap, norm=norm, zorder=zorder)

    return out
490

491
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        zorder=None, ax=None, _check_kwargs=True, suppress_warnings=False,
        **kwargs):
    """Plot stream lines in the phase plane.

    This function plots stream lines for a two-dimensional state space
    system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
        If not given, 'both' is used when `gridtype` is 'meshgrid' and
        'forward' otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    zorder : float, optional
        Set the zorder for the streamlines.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.plot`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects
        One entry per trajectory that produced more than one point; the
        list is empty if no trajectory could be drawn.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (and `_check_kwargs`
        is True).

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Parse the arrows keyword
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Determine the points on which to generate the streamlines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (trajectories are clipped to xlim/ylim)
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Create reverse time system (negated dynamics), if needed
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)
    else:
        revsys = None

    # Generate phase plane (streamline) data
    out = []
    for i, X0 in enumerate(points):
        # Create the trajectory for this point
        timepts = _make_timepts(timedata, i)
        traj = _create_trajectory(
            sys, revsys, timepts, X0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Plot the trajectory (if there is one)
        if traj.shape[1] > 1:
            with plt.rc_context(rcParams):
                out += ax.plot(traj[0], traj[1], color=color, zorder=zorder)

                # Add arrows to the lines at specified intervals
                _add_arrows_to_line2D(
                    ax, out[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return out
623

624

625
def equilpoints(
        sys, pointdata, gridspec=None, color='k', zorder=None, ax=None,
        _check_kwargs=True, **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate the
        points used when searching for equilibrium points.  The points are
        generated with the 'meshgrid' grid type.  Defaults to [5, 5].
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the equilibrium points in the given color (default 'k').
    zorder : float, optional
        Set the zorder for the equilibrium points.  If not specified, it
        will be automatically chosen by `matplotlib.axes.Axes.plot`.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects
        One marker per equilibrium point found.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (and `_check_kwargs`
        is True).

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Determine the points used to search for equilibrium points
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points
    out = []
    with plt.rc_context(rcParams):
        for xeq in equilpts:
            out += ax.plot(
                xeq[0], xeq[1], marker='o', color=color, zorder=zorder)
    return out
710

711

712
def separatrices(
9✔
713
        sys, pointdata, timedata=None, gridspec=None, zorder=None, ax=None,
714
        _check_kwargs=True, suppress_warnings=False, **kwargs):
715
    """Plot separatrices in the phase plane.
716

717
    This function plots separatrices for a two-dimensional state space
718
    system.
719

720
    Parameters
721
    ----------
722
    sys : `NonlinearIOSystem` or callable(t, x, ...)
723
        I/O system or function used to generate phase plane data. If a
724
        function is given, the remaining arguments are drawn from the
725
        `params` keyword.
726
    pointdata : list or 2D array
727
        List of the form [xmin, xmax, ymin, ymax] describing the
728
        boundaries of the phase plot or an array of shape (N, 2)
729
        giving points of at which to plot the vector field.
730
    timedata : int or list of int
731
        Time to simulate each streamline.  If a list is given, a different
732
        time can be used for each initial condition in `pointdata`.
733
    gridtype : str, optional
734
        The type of grid to use for generating initial conditions:
735
        'meshgrid' (default) generates a mesh of initial conditions within
736
        the specified boundaries, 'boxgrid' generates initial conditions
737
        along the edges of the boundary, 'circlegrid' generates a circle of
738
        initial conditions around each point in point data.
739
    gridspec : list, optional
740
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
741
        size of the grid in the x and y axes on which to generate points.
742
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
743
        specifying the radius and number of points around each point in the
744
        `pointdata` array.
745
    params : dict or list, optional
746
        Parameters to pass to system. For an I/O system, `params` should be
747
        a dict of parameters and values. For a callable, `params` should be
748
        dict with key 'args' and value given by a tuple (passed to callable).
749
    color : matplotlib color spec, optional
750
        Plot the separatrices in the given color.  If a single color
751
        specification is given, this is used for both stable and unstable
752
        separatrices.  If a tuple is given, the first element is used as
753
        the color specification for stable separatrices and the second
754
        element for unstable separatrices.
755
    zorder : float, optional
756
        Set the zorder for the separatrices.  In not specified, it will be
757
        automatically chosen by `matplotlib.axes.Axes.plot`.
758
    ax : `matplotlib.axes.Axes`, optional
759
        Use the given axes for the plot, otherwise use the current axes.
760

761
    Returns
762
    -------
763
    out : list of Line2D objects
764

765
    Other Parameters
766
    ----------------
767
    rcParams : dict
768
        Override the default parameters used for generating plots.
769
        Default is set by `config.defaults['ctrlplot.rcParams']`.
770
    suppress_warnings : bool, optional
771
        If set to True, suppress warning messages in generating trajectories.
772

773
    Notes
774
    -----
775
    The value of `config.defaults['separatrices_radius']` is used to set the
776
    offset from the equilibrium point to the starting point of the separatix
777
    traces, in the direction of the eigenvectors evaluated at that
778
    equilibrium point.
779

780
    """
781
    # Process keywords
782
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
783

784
    # Get system parameters
785
    params = kwargs.pop('params', None)
9✔
786

787
    # Create system from callable, if needed
788
    sys = _create_system(sys, params)
9✔
789

790
    # Parse the arrows keyword
791
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
792

793
    # Determine the initial states to use in searching for equilibrium points
794
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
795
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
796

797
    # Find the equilibrium points
798
    equilpts = _find_equilpts(sys, points, params=params)
9✔
799
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
800

801
    # Create axis if needed
802
    if ax is None:
9✔
803
        ax = plt.gca()
9✔
804

805
    # Set the axis limits
806
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
807

808
    # Figure out the color to use for stable, unstable subspaces
809
    color = _get_color(kwargs)
9✔
810
    match color:
9✔
811
        case None:
9✔
812
            stable_color = 'r'
9✔
813
            unstable_color = 'b'
9✔
814
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
815
            pass
9✔
816
        case single_color:
9✔
817
            stable_color = unstable_color = single_color
9✔
818

819
    # Make sure all keyword arguments were processed
820
    if _check_kwargs and kwargs:
9✔
821
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
822

823
    # Create a "reverse time" system to use for simulation
824
    revsys = NonlinearIOSystem(
9✔
825
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
826
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
827
        outputs=sys.noutputs, params=sys.params)
828

829
    # Plot separatrices by flowing backwards in time along eigenspaces
830
    out = []
9✔
831
    for i, xeq in enumerate(equilpts):
9✔
832
        # Plot the equilibrium points
833
        with plt.rc_context(rcParams):
9✔
834
            out += ax.plot(xeq[0], xeq[1], marker='o', color='k', zorder=zorder)
9✔
835

836
        # Figure out the linearization and eigenvectors
837
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
838

839
        # See if we have real eigenvalues (=> evecs are meaningful)
840
        if evals[0].imag > 0:
9✔
841
            continue
9✔
842

843
        # Create default list of time points
844
        if timedata is not None:
9✔
845
            timepts = _make_timepts(timedata, i)
9✔
846

847
        # Generate the traces
848
        for j, dir in enumerate(evecs.T):
9✔
849
            # Figure out time vector if not yet computed
850
            if timedata is None:
9✔
851
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
852
                timepts = np.linspace(0, timescale)
9✔
853

854
            # Run the trajectory starting in eigenvector directions
855
            for eps in [-radius, radius]:
9✔
856
                x0 = xeq + dir * eps
9✔
857
                if evals[j].real < 0:
9✔
858
                    traj = _create_trajectory(
9✔
859
                        sys, revsys, timepts, x0, params, 'reverse',
860
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
861
                        suppress_warnings=suppress_warnings)
862
                    color = stable_color
9✔
863
                    linestyle = '--'
9✔
864
                elif evals[j].real > 0:
9✔
865
                    traj = _create_trajectory(
9✔
866
                        sys, revsys, timepts, x0, params, 'forward',
867
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
868
                        suppress_warnings=suppress_warnings)
869
                    color = unstable_color
9✔
870
                    linestyle = '-'
9✔
871

872
                # Plot the trajectory (if there is one)
873
                if traj.shape[1] > 1:
9✔
874
                    with plt.rc_context(rcParams):
9✔
875
                        out += ax.plot(
9✔
876
                            traj[0], traj[1], color=color, linestyle=linestyle, zorder=zorder)
877

878
                    # Add arrows to the lines at specified intervals
879
                    with plt.rc_context(rcParams):
9✔
880
                        _add_arrows_to_line2D(
9✔
881
                            ax, out[-1], arrow_pos, arrowstyle=arrow_style,
882
                            dir=1)
883
    return out
9✔
884

885

886
#
887
# User accessible utility functions
888
#
889

890
# Utility function to generate boxgrid (in the form needed here)
891
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    ``points = boxgrid(xvals, yvals)`` walks the boundary of the box
    spanned by `xvals` and `yvals`.  It starts at ``(xvals[0], yvals[0])``
    and traverses the lower, right, upper and left edges in turn.  Each
    corner appears exactly once.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge.

    """
    # Each edge stops one sample short, so the next edge supplies the corner
    lower = [(x, yvals[0]) for x in xvals[:-1]]
    right = [(xvals[-1], y) for y in yvals[:-1]]
    upper = [(x, yvals[-1]) for x in reversed(xvals[1:])]
    left = [(xvals[0], y) for y in reversed(yvals[1:])]
    return np.array(lower + right + upper + left)
916

917

918
# Utility function to generate meshgrid (in the form needed here)
919
# TODO: add examples of using grid functions directly
920
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    ``points = meshgrid(xvals, yvals)`` returns every (x, y) pair from the
    cross product of the x and y values.  The x value varies fastest.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Values along the x and y axes from which the mesh is formed.

    Returns
    -------
    grid : 2D array
        Floating point array of points with shape (n * m, 2) defining the
        mesh.

    """
    xmesh, ymesh = np.meshgrid(xvals, yvals)

    # Pair the flattened coordinates column-wise; always return floats
    return np.stack((xmesh.ravel(), ymesh.ravel()), axis=-1).astype(float)
944

945

946
# Utility function to generate circular grid
947
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    ``points = circlegrid(centers, radius, num)`` places `num` equally
    spaced points on a circle of the given radius around each center.
    Each circle starts at angle 0 and proceeds counterclockwise.

    Parameters
    ----------
    centers : 2D array_like
        Array of points with shape (p, 2) defining centers of the circles.
        A single point is also accepted.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid : 2D array
        Array of points with shape (p * num, 2) defining the circles.

    """
    centers = np.atleast_2d(np.array(centers))

    # Offsets of the circle points relative to a center at the origin
    angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
    ring = np.array(
        [[radius * math.cos(a), radius * math.sin(a)] for a in angles])

    # Shift a copy of the ring to each center
    grid = np.zeros((centers.shape[0] * num, 2))
    for idx, center in enumerate(centers):
        start = idx * num
        grid[start:start + num, :] = center + ring
    return grid
975

976
#
977
# Internal utility functions
978
#
979

980
# Create a system from a callable
981
def _create_system(sys, params):
    """Return `sys` as a planar `NonlinearIOSystem`.

    If `sys` is already a `NonlinearIOSystem` it is returned unchanged,
    after checking that it has exactly 2 states.  Otherwise `sys` is
    treated as a callable ``f(t, x, *args)``.  It is wrapped in a system
    with 2 states, no inputs and no outputs.  The extra arguments are read
    from ``params['args']`` when the system is evaluated.

    Raises
    ------
    ValueError
        If the I/O system is not planar, or if `params` is given for a
        callable without a non-empty 'args' entry.

    """
    # Existing I/O systems are used as-is, as long as they are planar
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates == 2:
            return sys
        raise ValueError("system must be planar")

    # A callable gets its extra positional arguments from params['args']
    if params and not params.get('args', None):
        raise ValueError("params must be dict with key 'args'")

    def _update(t, x, u, params):
        return sys(t, x, *params.get('args', ()))

    def _output(t, x, u, params):
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
995

996
# Set axis limits for the plot
997
def _set_axis_limits(ax, pointdata):
    """Expand and apply the axis limits of `ax` to cover `pointdata`.

    If `ax` already contains lines, its current limits are widened as
    needed.  Otherwise the limits are taken from `pointdata` alone.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axes whose limits are updated.
    pointdata : list, tuple or ndarray
        Either a box [xmin, xmax, ymin, ymax] (list or tuple of length 4)
        or an array of points with shape (N, 2).

    Returns
    -------
    xlim, ylim : list
        The new [min, max] limits for the x and y axes.
    maxlim : float
        The larger of the x and y extents of the new limits.

    """
    # Get the current axis limits
    if ax.lines:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
    else:
        # Nothing on the plot => always use new limits
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]

    # Short utility function for updating axis limits
    def _update_limits(cur, new):
        return [min(cur[0], np.min(new)), max(cur[1], np.max(new))]

    # If we were passed a box, use that to update the limits.  Tuples are
    # accepted as well as lists: the grid generation code indexes the box
    # generically, and ignoring a tuple here left infinite limits.
    if isinstance(pointdata, (list, tuple)) and len(pointdata) == 4:
        xlim = _update_limits(xlim, [pointdata[0], pointdata[1]])
        ylim = _update_limits(ylim, [pointdata[2], pointdata[3]])

    elif isinstance(pointdata, np.ndarray):
        pointdata = np.atleast_2d(pointdata)
        xlim = _update_limits(
            xlim, [np.min(pointdata[:, 0]), np.max(pointdata[:, 0])])
        ylim = _update_limits(
            ylim, [np.min(pointdata[:, 1]), np.max(pointdata[:, 1])])

    # Keep track of the largest dimension on the plot
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])

    # Set the new limits
    ax.autoscale(enable=True, axis='x', tight=True)
    ax.autoscale(enable=True, axis='y', tight=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return xlim, ylim, maxlim
1031

1032

1033
# Find equilibrium points
1034
def _find_equilpts(sys, points, params=None):
    """Find the distinct equilibrium points of `sys` near `points`.

    Each entry of `points` is used as an initial guess for
    `find_operating_point` with zero input.  Guesses for which no
    operating point is returned are skipped.  Results that are
    `numpy.allclose` to a point already found are discarded.

    Returns
    -------
    equilpts : list of list
        The equilibrium states found, as Python lists, in order of
        discovery.

    """
    equilpts = []
    for x0 in points:
        # Look for an equilibrium point near this point
        xeq, _ = find_operating_point(sys, x0, 0, params=params)
        if xeq is None:
            continue            # didn't find anything

        # Skip points we have already found (stops at the first match)
        if any(np.allclose(np.array(x), xeq) for x in equilpts):
            continue

        # Save a new point
        equilpts.append(xeq.tolist())

    return equilpts
1055

1056

1057
def _make_points(pointdata, gridspec, gridtype):
9✔
1058
    # Check to see what type of data we got
1059
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
1060
        pointdata = np.atleast_2d(pointdata)
9✔
1061
        if pointdata.shape[1] == 2:
9✔
1062
            # Given a list of points => no action required
1063
            return pointdata, None
9✔
1064

1065
    # Utility function to parse (and check) input arguments
1066
    def _parse_args(defsize):
9✔
1067
        if gridspec is None:
9✔
1068
            return defsize
9✔
1069

1070
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
1071
             len(gridspec) != len(defsize):
1072
            raise ValueError("invalid grid specification")
9✔
1073

1074
        return gridspec
9✔
1075

1076
    # Generate points based on grid type
1077
    match gridtype:
9✔
1078
        case 'boxgrid' | None:
9✔
1079
            gridspec = _parse_args([6, 4])
9✔
1080
            points = boxgrid(
9✔
1081
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1082
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1083

1084
        case 'meshgrid':
9✔
1085
            gridspec = _parse_args([9, 6])
9✔
1086
            points = meshgrid(
9✔
1087
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1088
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1089

1090
        case 'circlegrid':
9✔
1091
            gridspec = _parse_args((0.5, 10))
9✔
1092
            if isinstance(pointdata, np.ndarray):
9✔
1093
                # Create circles around each point
1094
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
1095
            else:
1096
                # Create circle around center of the plot
1097
                points = circlegrid(
9✔
1098
                    np.array(
1099
                        [(pointdata[0] + pointdata[1]) / 2,
1100
                         (pointdata[0] + pointdata[1]) / 2]),
1101
                    gridspec[0], gridspec[1])
1102

1103
        case _:
9✔
1104
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
1105

1106
    return points, gridspec
9✔
1107

1108

1109
def _parse_arrow_keywords(kwargs):
    """Extract the arrow placement and style keywords from `kwargs`.

    The 'arrows', 'arrow_size' and 'arrow_style' keywords are removed from
    `kwargs` when present.  This lets callers check the remaining keywords
    for unrecognized entries.  Missing values are taken from the
    'phaseplot' configuration defaults.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments.  Modified in place.

    Returns
    -------
    arrow_pos : list or 1D array
        Arrow positions handed on to the arrow-drawing helper.
        - An integer N gives N values evenly spaced starting at 0.5/N.
          These are presumably fractions of the curve length.
        - A list, tuple or array gives the sorted values.
        - A false value gives an empty list.
    arrow_style : `matplotlib.patches.ArrowStyle`
        The given style, or a 'simple' arrow sized from 'arrow_size'.

    Raises
    ------
    ValueError
        If 'arrows' has an unsupported type.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    # Pop the style too: leaving it in kwargs made callers' "unrecognized
    # keywords" check reject a valid 'arrow_style' argument
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # Parse the arrows keyword.  Explicit positions are checked first, since
    # the truth value of a multi-element array is ambiguous.
    if isinstance(arrows, (list, tuple, np.ndarray)):
        arrow_pos = np.sort(np.atleast_1d(arrows)) if np.size(arrows) else []
    elif not arrows:
        arrow_pos = []
    elif isinstance(arrows, (int, np.integer)):
        N = int(arrows)
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
1137

1138

1139
# TODO: move to ctrlplot?
1140
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory from `X0` and clip it to the plot window.

    Parameters
    ----------
    sys : `NonlinearIOSystem`
        System used for forward-time simulation.
    revsys : `NonlinearIOSystem`
        Time-reversed version of `sys`, used for reverse-time simulation.
    timepts : array_like
        Time points for the simulation(s).
    X0 : array_like
        Initial state.
    params : dict or None
        Parameters passed to the simulations.
    dir : {'forward', 'reverse', 'both'}
        Which direction(s) in time to simulate.
    suppress_warnings : bool, optional
        If True, do not warn when a simulation reports failure.
    gridtype, gridspec : optional
        Not used by this function.
    xlim, ylim : 2-element sequence, optional
        Plot window.  Points outside it are removed, except those adjacent
        to an in-range point, so the trace reaches the boundary.  A None
        limit disables clipping along that axis.

    Returns
    -------
    traj : 2D array
        State trajectory with one row per state and one column per
        retained time point, ordered in increasing (forward) time.

    Raises
    ------
    ValueError
        If `dir` is not one of the recognized values.

    """
    if dir not in ('forward', 'reverse', 'both'):
        raise ValueError(f"unknown trajectory direction '{dir}'")

    # Compute the forward trajectory
    if dir in ('forward', 'both'):
        fwdresp = input_output_response(
            sys, timepts, X0=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir in ('reverse', 'both'):
        revresp = input_output_response(
            revsys, timepts, X0=X0, params=params, ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {revresp.message}")

    # Create the trace to plot
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Reverse the backward trace and drop only its first sample (X0),
        # which is repeated as the first sample of the forward trace
        traj = np.hstack([revresp.states[:, :0:-1], fwdresp.states])

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.ones(traj.shape[1], dtype=bool)
    if xlim is not None:
        inrange &= (traj[0] >= xlim[0]) & (traj[0] <= xlim[1])
    if ylim is not None:
        inrange &= (traj[1] >= ylim[0]) & (traj[1] <= ylim[1])
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
1173

1174

1175
def _make_timepts(timepts, i):
    """Return the time vector to use for the i-th trajectory.

    Parameters
    ----------
    timepts : None, number, list/tuple, or ndarray
        How `timepts` is interpreted depends on its form:
        - None: use ``np.linspace(0, 1)``.
        - A scalar (Python or NumPy) T: use ``np.linspace(0, T)``.
        - A list or tuple: one entry per trajectory, either a final time
          or a vector of time points.
        - A 1D array: used as the time vector for every trajectory.
        - A 2D array: row `i` is used.
    i : int
        Index of the trajectory.

    Returns
    -------
    1D array
        Time points for the simulation.

    """
    if timepts is None:
        return np.linspace(0, 1)

    # A list gives a separate entry (final time or time vector) for each
    # trajectory; lists have no `.ndim`, so this previously failed
    if isinstance(timepts, (list, tuple)):
        entry = timepts[i]
        if np.ndim(entry) == 0:
            return np.linspace(0, entry)
        return np.asarray(entry)

    # Any scalar (including NumPy integer types) is a final time
    if np.ndim(timepts) == 0:
        return np.linspace(0, timepts)
    elif timepts.ndim == 2:
        return timepts[i]
    return timepts
1183

1184

1185
#
1186
# Legacy phase plot function
1187
#
1188
# Author: Richard Murray
1189
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1190
# a version by Kristi Morgansen
1191
#
1192
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):
    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use `phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by the `phase_plane_plot` function.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - stream lines
        with arrows spaced evenly in time index
      phase_plot(func, X0=[...], logtime=(N, lambda), ...) - stream lines
        with arrows spaced in time by 1/lambda

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Computes the time derivative of the state (compatible with
        `scipy.integrate.odeint`).  It should accept a state x of
        dimension 2 and return a derivative dx/dt of dimension 2.  By
        default it is called as ``odefun(x, t, *params)``.  If `tfirst` is
        True it is called as ``odefun(t, x, *params)``.
    X, Y : 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.  The vector
        field is also not plotted when `lingrid`, `logtime` or `timepts`
        is given.
    scale : float, optional
        Scale size of arrows; default = 1.  If 0, arrows are not drawn.
        If None, the current axis limits decide which streamline arrows
        are drawn.  If negative, dots are added at the base of the
        streamline arrows.
    X0 : ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T : array_like or number, optional
        Length of time to run simulations that generate streamlines.
        If a single number, each simulation uses 100 points from 0 to
        `T`.  Otherwise, `T` is used as the array of time points for every
        simulation.  Default value = 50.
    lingrid : integer, optional
        Draw `lingrid` arrows along each streamline, spaced evenly in
        simulation time index.
    lintime : optional
        Not used; accepted for backward compatibility.
    logtime : tuple (integer, float), optional
        If a tuple (N, lambda) is given, draw N arrows along each
        streamline, at times spaced by 1/lambda.
    timepts : array_like, optional
        Draw arrows at the given list times [t1, t2, ...].
    parms : tuple, optional
        Deprecated alias for `params`.
    params : tuple, optional
        List of parameters to pass to vector field: ``func(x, t, *params)``.
    tfirst : bool, optional
        If True, call `func` with signature ``func(t, x, ...)``.
    verbose : bool, optional
        If True, print a message describing the arrow mode in use.

    See Also
    --------
    box_grid

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            "keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument("duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if (verbose):
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if (verbose):
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # Arrows along the streamlines are only computed (and hence only
    # plotted) when one of the lingrid/logtime/timepts options was given
    streamArrows = autoFlag or logtimeFlag or timeptsFlag

    if not streamArrows and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i,j], x2[i,j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i,j], x2[i,j]], 0, *params))

        # Plot the quiver plot; the velocity components are dx[..., 0] and
        # dx[..., 1] (index 2 would be out of range)
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif (scale != 0):
            plt.quiver(x1, x2, dx[:,:,0]*np.abs(scale),
                       dx[:,:,1]*np.abs(scale), angles='xy')

        plt.xlabel('x1'); plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    (nr, nc) = np.shape(X0)

    # Generate some empty matrices to keep arrow information
    x1 = np.empty((nr, Narrows))
    x2 = np.empty((nr, Narrows))
    dx = np.empty((nr, Narrows, 2))

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed
    TSPAN = T
    if isinstance(T, (int, float)):
        TSPAN = np.linspace(0, T, 100)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin = alim[0]; xmax = alim[1]
        ymin = alim[2]; ymax = alim[3]
    else:
        # Use the maximum extent of all trajectories
        xmin = np.min(X0[:,0]); xmax = np.max(X0[:,0])
        ymin = np.min(X0[:,1]); ymax = np.max(X0[:,1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = TSPAN

        plt.plot(state[:,0], state[:,1])
        #! TODO: add back in colors for stream lines

        # Compute arrows along the streamline if requested
        if streamArrows:
            for j in range(Narrows):

                # Figure out starting index; headless arrows start at 0
                k = -1 if scale is None else 0

                # Figure out what time index to use for the next point
                if autoFlag:
                    # Use a linear scaling based on ODE time vector; the
                    # index must be an int to be usable for array indexing
                    tind = int(np.floor((len(time)/Narrows) * (j-k))) + k
                elif logtimeFlag:
                    # Last time point before (j - k) / lambda
                    tarr = _find(time < (j-k) / timefactor)
                    tind = tarr[-1] if len(tarr) else 0
                elif timeptsFlag:
                    # Last time point before the requested time
                    tarr = _find(time < timepts[j])
                    tind = tarr[-1] if len(tarr) else 0

                # For tailless arrows, skip the first point
                if tind == 0 and scale is None:
                    continue

                # Figure out the arrow at this point on the curve
                x1[i,j] = state[tind, 0]
                x2[i,j] = state[tind, 1]

                # Skip arrows outside of initial condition box
                if (scale is not None or
                     (x1[i,j] <= xmax and x1[i,j] >= xmin and
                      x2[i,j] <= ymax and x2[i,j] >= ymin)):
                    if tfirst:
                        v = odefun(0, [x1[i,j], x2[i,j]], *params)
                    else:
                        v = odefun([x1[i,j], x2[i,j]], 0, *params)
                    dx[i, j, 0] = v[0]; dx[i, j, 1] = v[1]
                else:
                    dx[i, j, 0] = 0; dx[i, j, 1] = 0

    # Plot arrows on the streamlines (only if they were computed above;
    # otherwise x1, x2, dx hold uninitialized values)
    if streamArrows and Narrows > 0:
        if scale is None:
            # Use a tailless arrow
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif scale != 0:
            plt.quiver(x1, x2, dx[:,:,0]*abs(scale), dx[:,:,1]*abs(scale),
                       angles='xy')

        # A None scale cannot be compared with 0
        if scale is not None and scale < 0:
            plt.plot(x1, x2, 'b.')        # add dots at base
1456

1457

1458
# Utility function for generating initial conditions around a box
1459
def box_grid(xlimp, ylimp):
    """Generate list of points on edge of box.

    .. deprecated:: 0.10.0
        Use `phaseplot.boxgrid` instead.

    ``box_grid([xmin, xmax, xnum], [ymin, ymax, ynum])`` returns points on
    the boundary of the box with corners [xmin, ymin] and [xmax, ymax].
    It uses `xnum` and `ynum` evenly spaced values along the x and y axes.

    Returns
    -------
    grid : 2D array
        Points along the edge of the box, as computed by `boxgrid`.

    """
    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    # Expand the [start, stop, num] specifications into sample values
    xsamples = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    ysamples = np.linspace(ylimp[0], ylimp[1], ylimp[2])
    return boxgrid(xsamples, ysamples)
1479

1480

1481
# TODO: rename to something more useful (or remove??)
1482
def _find(condition):
    """Return the flat indices at which `condition` is true.

    Private replacement for the deprecated `matplotlib.mlab.find`.
    `condition` is flattened (as by `numpy.ravel`), and the indices of its
    nonzero entries are returned as a 1D integer array.

    """
    return np.flatnonzero(condition)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc