• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13144560479

04 Feb 2025 08:48PM UTC coverage: 94.725% (-0.03%) from 94.752%
13144560479

Pull #1112

github

web-flow
Merge ad4083259 into f73e893e8
Pull Request #1112: Use matplotlibs streamplot function for phase_plane_plot

9715 of 10256 relevant lines covered (94.73%)

8.28 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.01
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Initial author: Richard M. Murray
4
# Creation date: 24 July 2011, converted from MATLAB version (2002);
5
# based on an original version by Kristi Morgansen
6

7
"""Generate 2D phase portraits.
8

9
This module contains functions for generating 2D phase plots. The base
10
function for creating phase plane portraits is `~control.phase_plane_plot`,
11
which generates a phase plane portrait for a 2 state I/O system (with no
12
inputs). Utility functions are available to customize the individual
13
elements of a phase plane portrait.
14

15
The docstring examples assume the following import commands::
16

17
  >>> import numpy as np
18
  >>> import control as ct
19
  >>> import control.phaseplot as pp
20

21
"""
22

23
import math
9✔
24
import warnings
9✔
25

26
import matplotlib as mpl
9✔
27
import matplotlib.pyplot as plt
9✔
28
import numpy as np
9✔
29
from scipy.integrate import odeint
9✔
30

31
from . import config
9✔
32
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
33
    _process_ax_keyword, _update_plot_title
34
from .exception import ControlArgument
9✔
35
from .nlsys import NonlinearIOSystem, find_operating_point, \
9✔
36
    input_output_response
37

38
__all__ = ['phase_plane_plot', 'phase_plot', 'box_grid']
9✔
39

40
# Default values for module parameter variables.  Each entry can be
# overridden by setting the same key in `config.defaults`.
_phaseplot_defaults = {
    # Number of arrows drawn along each curve
    'phaseplot.arrows': 2,
    # Arrow size, in pixels
    'phaseplot.arrow_size': 8,
    # Arrow style (matplotlib patch style), or None for no explicit style
    'phaseplot.arrow_style': None,
    # Initial radius used when starting separatrices
    'phaseplot.separatrices_radius': 0.1,
}
47

48
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=None, plot_vectorfield=None, plot_streamplot=None,
        plot_equilpoints=True, plot_separatrices=True, ax=None,
        suppress_warnings=False, title=None, **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.
    If none of `plot_streamlines`, `plot_vectorfield`, or `plot_streamplot`
    are set, then `plot_streamlines` is used by default.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  Defaults to
        [-1, 1, -1, 1].
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.  If `plot_streamplot` is the only flow-field
        element selected (no streamlines or vector field), the default
        is [25, 25].
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use ``plot_<element>`` =
        {'color': c} to set the color in one element of the phase
        plot (equilpoints, separatrices, streamlines, etc)).
    ax : `matplotlib.axes.Axes`, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : `ControlPlot` object
        Object containing the data that were plotted.  See `ControlPlot`
        for more detailed information.
    cplt.lines : array of list of `matplotlib.lines.Line2D`
        Array of list of `matplotlib.artist.Artist` objects:

            - lines[0] = list of Line2D objects (streamlines, separatrices).
            - lines[1] = Quiver object (vector field arrows).
            - lines[2] = list of Line2D objects (equilibrium points).
            - lines[3] = StreamplotSet object (lines with arrows).

        Entries for elements that are not plotted are None (lines[0] is
        an empty list if neither streamlines nor separatrices are plotted).
    cplt.axes : 2D array of `matplotlib.axes.Axes`
        Axes for each subplot.
    cplt.figure : `matplotlib.figure.Figure`
        Figure containing the plot.

    Raises
    ------
    TypeError
        If a keyword argument is not used by any of the plotted elements.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
    plot_streamlines : bool or dict, optional
        If True then plot streamlines based on the pointdata and gridtype.
        If set to a dict, pass on the key-value pairs in the dict as
        keywords to `streamlines`.
    plot_vectorfield : bool or dict, optional
        If True then plot the vector field based on the pointdata and
        gridtype.  If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.vectorfield`.
    plot_streamplot : bool or dict, optional
        If True then use :func:`matplotlib.axes.Axes.streamplot` function
        to plot the streamlines.  If set to a dict, pass on the key-value
        pairs in the dict as keywords to :func:`~control.phaseplot.streamplot`.
    plot_equilpoints : bool or dict, optional
        If True (default) then plot equilibrium points based on the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If True (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to `phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories
        (streamlines and separatrices).
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    """
    # Default to streamlines if no flow-field element was requested
    if (
        plot_streamlines is None
        and plot_vectorfield is None
        and plot_streamplot is None
    ):
        plot_streamlines = True

    # Use a denser default grid when streamplot is the only flow element
    if plot_streamplot and not plot_streamlines and not plot_vectorfield:
        gridspec = gridspec or [25, 25]

    # Process arguments
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments.  Priority (lowest to
    # highest): global keywords, explicit overrides, per-element dict.
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Create array for storing outputs (one slot per entry documented
    # for cplt.lines, including the streamplot result in slot 3)
    out = np.array([[], None, None, None], dtype=object)

    # Plot out the main elements
    if plot_streamlines:
        # suppress_warnings is routed through _create_kwargs so that a
        # value in a plot_streamlines dict overrides it instead of
        # raising a duplicate keyword error
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax, suppress_warnings=suppress_warnings)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_separatrices:
        # Separatrices also simulate trajectories, so honor the
        # suppress_warnings setting for them as well
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax,
            suppress_warnings=suppress_warnings)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_streamplot:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
        out[3] = streamplot(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by streamplot
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError(f"unrecognized keywords: {initial_kwargs}")

    # Only decorate the figure if we created the axes ourselves
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
9✔
267

268

269
def vectorfield(
        sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate points.
        A mesh grid is always used for the vector field.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver
        The object returned by `matplotlib.axes.Axes.quiver`.

    Raises
    ------
    TypeError
        If `_check_kwargs` is True and unrecognized keywords are given.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions; it
        has no effect here since no trajectories are simulated.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (called for its effect on the axes)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Generate phase plane (quiver) data: each row holds the point (x, y)
    # followed by the zero-input state derivative at t = 0
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, np.zeros(sys.ninputs))

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color)

    return out
9✔
360

361

362
def streamplot(
        sys, pointdata, gridspec=None, ax=None, vary_color=False,
        vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot streamlines in the phase plane using matplotlib's streamplot.

    This function evaluates the vector field of a two-dimensional state
    space system on a grid of points and draws its streamlines with
    `matplotlib.axes.Axes.streamplot`.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points from which to make the streamplot. In the latter
        case, the points must lie on a grid like that generated by
        `meshgrid`, flattened with the x coordinate varying fastest.
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate points.  If the grid size cannot be determined from the
        arguments, it is recovered from the ordering of the points in
        `pointdata`.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the streamlines in the given color.  Ignored if `vary_color`
        is True.
    vary_color : bool, optional
        If set to True, vary the color of the streamlines based on the
        magnitude of the vector field.
    vary_linewidth : bool, optional
        If set to True, vary the linewidth of the streamlines based on the
        magnitude of the vector field.
    cmap : str or Colormap, optional
        Colormap to use for varying the color of the streamlines.
    norm : `matplotlib.colors.Normalize`, optional
        An instance of Normalize to use for scaling the colormap and
        linewidths.  If not given, a default `matplotlib.colors.Normalize`
        is used.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : StreamplotSet
        The object returned by `matplotlib.axes.Axes.streamplot`.

    Raises
    ------
    ValueError
        If the grid size has to be recovered from `pointdata` and the
        points do not form a regular grid.
    TypeError
        If `_check_kwargs` is True and unrecognized keywords are given.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions; it
        has no effect here since no trajectories are simulated.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the streamplot field
    points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')

    # If the grid size is unknown, recover it from the point ordering.
    # Points are assumed to be a flattened meshgrid with x varying fastest,
    # so each decrease in x starts a new row (i.e., a new y value).
    if gridspec is None:
        nrows = np.sum(np.diff(points[:, 0]) < 0) + 1   # number of y values
        ncols = points.shape[0] // nrows                # number of x values
        if nrows * ncols != points.shape[0]:
            raise ValueError("Could not recover grid from points.")
        # gridspec is ordered [nx, ny] (reversed into the array shape below)
        gridspec = [ncols, nrows]

    # gridspec is [nx, ny]; the 2D arrays given to streamplot have one row
    # per y value, so their shape is (ny, nx)
    grid_arr_shape = gridspec[::-1]
    xs = points[:, 0].reshape(grid_arr_shape)
    ys = points[:, 1].reshape(grid_arr_shape)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (called for its effect on the axes)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Evaluate the zero-input vector field at each grid point (t = 0)
    sys._update_params(params)
    us_flat, vs_flat = np.transpose(
        [sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
    us = us_flat.reshape(grid_arr_shape)
    vs = vs_flat.reshape(grid_arr_shape)

    # Normalize the vector magnitudes (used for color/linewidth variation)
    magnitudes = np.linalg.norm([us, vs], axis=0)
    if norm is None:
        norm = mpl.colors.Normalize()
    normalized = norm(magnitudes)
    cmap = plt.get_cmap(cmap)

    with plt.rc_context(rcParams):
        # Linewidths span 0.25x to 2x the default line width
        default_lw = plt.rcParams['lines.linewidth']
        min_lw, max_lw = 0.25 * default_lw, 2 * default_lw
        linewidths = (
            normalized * (max_lw - min_lw) + min_lw
            if vary_linewidth else None)
        color = magnitudes if vary_color else color

        out = ax.streamplot(
            xs, ys, us, vs, color=color, linewidth=linewidths,
            cmap=cmap, norm=norm)

    return out
9✔
471

472
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):
    """Plot stream lines in the phase plane.

    This function plots stream lines for a two-dimensional state space
    system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to start the streamlines.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
        If not given, 'both' is used when `gridtype` is explicitly set to
        'meshgrid' and 'forward' otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If `_check_kwargs` is True and unrecognized keywords are given.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Parse the arrows keyword
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Determine the points on which to generate the streamlines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits; the limits are also used to bound trajectories
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Create reverse time system (negated dynamics), if needed
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)
    else:
        revsys = None

    # Generate phase plane (streamline) data
    out = []
    for i, X0 in enumerate(points):
        # Create the trajectory for this point
        timepts = _make_timepts(timedata, i)
        traj = _create_trajectory(
            sys, revsys, timepts, X0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Plot the trajectory (if there is one)
        if traj.shape[1] > 1:
            with plt.rc_context(rcParams):
                out += ax.plot(traj[0], traj[1], color=color)

                # Add arrows to the lines at specified intervals
                _add_arrows_to_line2D(
                    ax, out[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return out
9✔
600

601

602
def equilpoints(
        sys, pointdata, gridspec=None, color='k', ax=None,
        _check_kwargs=True, **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving the points passed to the equilibrium point search.
    gridspec : list, optional
        Size of the grid in the x and y axes used to generate the points
        passed to the equilibrium point search.  A mesh grid is always
        used.  Default is [5, 5].
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the equilibrium points in the given color (default 'k').
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If `_check_kwargs` is True and unrecognized keywords are given.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (called for its effect on the axes)
    _set_axis_limits(ax, pointdata)

    # Determine the points from which to search for equilibrium points
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points
    out = []
    for xeq in equilpts:
        with plt.rc_context(rcParams):
            out += ax.plot(xeq[0], xeq[1], marker='o', color=color)
    return out
9✔
684

685

686
def separatrices(
9✔
687
        sys, pointdata, timedata=None, gridspec=None, ax=None,
688
        _check_kwargs=True, suppress_warnings=False, **kwargs):
689
    """Plot separatrices in the phase plane.
690

691
    This function plots separatrices for a two-dimensional state space
692
    system.
693

694
    Parameters
695
    ----------
696
    sys : `NonlinearIOSystem` or callable(t, x, ...)
697
        I/O system or function used to generate phase plane data. If a
698
        function is given, the remaining arguments are drawn from the
699
        `params` keyword.
700
    pointdata : list or 2D array
701
        List of the form [xmin, xmax, ymin, ymax] describing the
702
        boundaries of the phase plot or an array of shape (N, 2)
703
        giving points of at which to plot the vector field.
704
    timedata : int or list of int
705
        Time to simulate each streamline.  If a list is given, a different
706
        time can be used for each initial condition in `pointdata`.
707
    gridtype : str, optional
708
        The type of grid to use for generating initial conditions:
709
        'meshgrid' (default) generates a mesh of initial conditions within
710
        the specified boundaries, 'boxgrid' generates initial conditions
711
        along the edges of the boundary, 'circlegrid' generates a circle of
712
        initial conditions around each point in point data.
713
    gridspec : list, optional
714
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
715
        size of the grid in the x and y axes on which to generate points.
716
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
717
        specifying the radius and number of points around each point in the
718
        `pointdata` array.
719
    params : dict or list, optional
720
        Parameters to pass to system. For an I/O system, `params` should be
721
        a dict of parameters and values. For a callable, `params` should be
722
        dict with key 'args' and value given by a tuple (passed to callable).
723
    color : matplotlib color spec, optional
724
        Plot the separatrices in the given color.  If a single color
725
        specification is given, this is used for both stable and unstable
726
        separatrices.  If a tuple is given, the first element is used as
727
        the color specification for stable separatrices and the second
728
        element for unstable separatrices.
729
    ax : `matplotlib.axes.Axes`, optional
730
        Use the given axes for the plot, otherwise use the current axes.
731

732
    Returns
733
    -------
734
    out : list of Line2D objects
735

736
    Other Parameters
737
    ----------------
738
    rcParams : dict
739
        Override the default parameters used for generating plots.
740
        Default is set by `config.defaults['ctrlplot.rcParams']`.
741
    suppress_warnings : bool, optional
742
        If set to True, suppress warning messages in generating trajectories.
743

744
    Notes
745
    -----
746
    The value of `config.defaults['separatrices_radius']` is used to set the
747
    offset from the equilibrium point to the starting point of the separatix
748
    traces, in the direction of the eigenvectors evaluated at that
749
    equilibrium point.
750

751
    """
752
    # Process keywords
753
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
754

755
    # Get system parameters
756
    params = kwargs.pop('params', None)
9✔
757

758
    # Create system from callable, if needed
759
    sys = _create_system(sys, params)
9✔
760

761
    # Parse the arrows keyword
762
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
763

764
    # Determine the initial states to use in searching for equilibrium points
765
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
766
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
767

768
    # Find the equilibrium points
769
    equilpts = _find_equilpts(sys, points, params=params)
9✔
770
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
771

772
    # Create axis if needed
773
    if ax is None:
9✔
774
        ax = plt.gca()
9✔
775

776
    # Set the axis limits
777
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
778

779
    # Figure out the color to use for stable, unstable subspaces
780
    color = _get_color(kwargs)
9✔
781
    match color:
9✔
782
        case None:
9✔
783
            stable_color = 'r'
9✔
784
            unstable_color = 'b'
9✔
785
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
786
            pass
9✔
787
        case single_color:
9✔
788
            stable_color = unstable_color = single_color
9✔
789

790
    # Make sure all keyword arguments were processed
791
    if _check_kwargs and kwargs:
9✔
792
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
793

794
    # Create a "reverse time" system to use for simulation
795
    revsys = NonlinearIOSystem(
9✔
796
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
797
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
798
        outputs=sys.noutputs, params=sys.params)
799

800
    # Plot separatrices by flowing backwards in time along eigenspaces
801
    out = []
9✔
802
    for i, xeq in enumerate(equilpts):
9✔
803
        # Plot the equilibrium points
804
        with plt.rc_context(rcParams):
9✔
805
            out += ax.plot(xeq[0], xeq[1], marker='o', color='k')
9✔
806

807
        # Figure out the linearization and eigenvectors
808
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
809

810
        # See if we have real eigenvalues (=> evecs are meaningful)
811
        if evals[0].imag > 0:
9✔
812
            continue
9✔
813

814
        # Create default list of time points
815
        if timedata is not None:
9✔
816
            timepts = _make_timepts(timedata, i)
9✔
817

818
        # Generate the traces
819
        for j, dir in enumerate(evecs.T):
9✔
820
            # Figure out time vector if not yet computed
821
            if timedata is None:
9✔
822
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
823
                timepts = np.linspace(0, timescale)
9✔
824

825
            # Run the trajectory starting in eigenvector directions
826
            for eps in [-radius, radius]:
9✔
827
                x0 = xeq + dir * eps
9✔
828
                if evals[j].real < 0:
9✔
829
                    traj = _create_trajectory(
9✔
830
                        sys, revsys, timepts, x0, params, 'reverse',
831
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
832
                        suppress_warnings=suppress_warnings)
833
                    color = stable_color
9✔
834
                    linestyle = '--'
9✔
835
                elif evals[j].real > 0:
9✔
836
                    traj = _create_trajectory(
9✔
837
                        sys, revsys, timepts, x0, params, 'forward',
838
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
839
                        suppress_warnings=suppress_warnings)
840
                    color = unstable_color
9✔
841
                    linestyle = '-'
9✔
842

843
                # Plot the trajectory (if there is one)
844
                if traj.shape[1] > 1:
9✔
845
                    with plt.rc_context(rcParams):
9✔
846
                        out += ax.plot(
9✔
847
                            traj[0], traj[1], color=color, linestyle=linestyle)
848

849
                    # Add arrows to the lines at specified intervals
850
                    with plt.rc_context(rcParams):
9✔
851
                        _add_arrows_to_line2D(
9✔
852
                            ax, out[-1], arrow_pos, arrowstyle=arrow_style,
853
                            dir=1)
854
    return out
9✔
855

856

857
#
858
# User accessible utility functions
859
#
860

861
# Utility function to generate boxgrid (in the form needed here)
862
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    The points trace the boundary of the box spanned by `xvals` and
    `yvals`, going counter-clockwise from the lower-left corner (lower
    edge, right edge, upper edge, left edge).  Each corner appears once.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge.

    """
    perimeter = []

    # Each edge omits its final point because that point starts the next
    # edge, so every corner is listed exactly once
    for x in xvals[:-1]:                    # lower edge, left to right
        perimeter.append((x, yvals[0]))
    for y in yvals[:-1]:                    # right edge, bottom to top
        perimeter.append((xvals[-1], y))
    for x in xvals[:0:-1]:                  # upper edge, right to left
        perimeter.append((x, yvals[-1]))
    for y in yvals[:0:-1]:                  # left edge, top to bottom
        perimeter.append((xvals[0], y))

    return np.array(perimeter)
887

888

889
# Utility function to generate meshgrid (in the form needed here)
890
# TODO: add examples of using grid functions directly
891
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    The result lists every (x, y) combination of the given values, with
    x varying fastest: all x values for ``yvals[0]`` come first, then all
    x values for ``yvals[1]``, and so on.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Floating point array of points with shape (n * m, 2) defining the
        mesh.

    """
    xgrid, ygrid = np.meshgrid(xvals, yvals)

    # Flatten row by row and pair up the coordinates as columns; always
    # return floats, independent of the input dtype
    return np.column_stack(
        (xgrid.reshape(-1), ygrid.reshape(-1))).astype(float)
915

916

917
# Utility function to generate circular grid
918
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    For each center, `num` points are placed at equally spaced angles
    (starting at angle 0 and excluding 2*pi) on a circle of the given
    radius.  The points for each center are stored contiguously, in the
    order the centers are given.

    Parameters
    ----------
    centers : 2D array_like
        Array of points with shape (p, 2) defining centers of the circles.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid : 2D array
        Array of points with shape (p * num, 2) defining the circles.

    """
    center_array = np.atleast_2d(np.array(centers))
    ncenters = center_array.shape[0]
    grid = np.zeros((ncenters * num, 2))

    # The ring of offsets is the same for every center, so compute it once
    if ncenters:
        angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
        ring = np.array([
            [radius * math.cos(angle), radius * math.sin(angle)]
            for angle in angles])

    for idx in range(ncenters):
        start = idx * num
        grid[start:start + num, :] = center_array[idx] + ring
    return grid
946

947
#
948
# Internal utility functions
949
#
950

951
# Create a system from a callable
952
def _create_system(sys, params):
    """Return `sys` as a planar `NonlinearIOSystem`.

    An existing `NonlinearIOSystem` is checked to have two states and is
    returned unchanged.  Otherwise `sys` is treated as a callable
    ``sys(t, x, *args)`` returning the state derivative.  It is wrapped in
    a two-state system with no inputs or outputs, which reads ``args`` from
    the 'args' entry of the params dict at evaluation time.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable
        System or vector field to wrap.
    params : dict or None
        For a callable, either empty/None or a dict with key 'args' whose
        value is a tuple of extra arguments (an empty tuple is allowed).

    Returns
    -------
    `NonlinearIOSystem`

    Raises
    ------
    ValueError
        If a `NonlinearIOSystem` does not have exactly two states, or if
        `params` is given for a callable but is not a dict with an 'args'
        entry.

    """
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates != 2:
            raise ValueError("system must be planar")
        return sys

    # Make sure that if params is present, it is a dict with an 'args'
    # entry.  Test for the entry itself rather than the truthiness of its
    # value, so that an empty argument tuple ({'args': ()}) is accepted.
    if params and (
            not isinstance(params, dict) or params.get('args') is None):
        raise ValueError("params must be dict with key 'args'")

    def _update(t, x, u, params):
        # params is the dict supplied at simulation time
        return sys(t, x, *params.get('args', ()))

    def _output(t, x, u, params):
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
966

967
# Set axis limits for the plot
968
def _set_axis_limits(ax, pointdata):
    """Combine the current axis limits with `pointdata` and apply them.

    If `ax` already contains lines, its current limits are the starting
    point; otherwise the limits start out empty, so `pointdata` alone
    determines them.  A 4-element list ``[xmin, xmax, ymin, ymax]``
    contributes its bounds.  An array of (x, y) rows contributes the extent
    of its points.  Tight autoscaling is enabled on both axes before the
    limits are set.

    Returns
    -------
    xlim, ylim : sequence of 2 values
        The x and y limits applied to `ax`.
    maxlim : float
        The larger of the x and y ranges.

    """
    def _widen(bounds, values):
        # Grow [lo, hi] so that it also covers every entry of `values`
        return [min(bounds[0], np.min(values)),
                max(bounds[1], np.max(values))]

    if not ax.lines:
        # Empty plot: start from an inverted range so any data replaces it
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]
    else:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()

    if isinstance(pointdata, list) and len(pointdata) == 4:
        xmin, xmax, ymin, ymax = pointdata
        xlim = _widen(xlim, [xmin, xmax])
        ylim = _widen(ylim, [ymin, ymax])
    elif isinstance(pointdata, np.ndarray):
        rows = np.atleast_2d(pointdata)
        xlim = _widen(xlim, rows[:, 0])
        ylim = _widen(ylim, rows[:, 1])

    # Largest dimension on the plot (used for time scaling by callers)
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])

    for axis_name in ('x', 'y'):
        ax.autoscale(enable=True, axis=axis_name, tight=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return xlim, ylim, maxlim
1002

1003

1004
# Find equilibrium points
1005
def _find_equilpts(sys, points, params=None):
    """Search for distinct equilibrium points starting from `points`.

    `find_operating_point` is called with each point as the initial guess
    (and zero input).  Failed searches (``None`` result) are ignored, and
    results that are `np.allclose` to an equilibrium already found are
    dropped.

    Returns
    -------
    list of list
        The distinct equilibrium states, each converted to a list, in the
        order they were first found.

    """
    found = []
    for guess in points:
        xeq, _ = find_operating_point(sys, guess, 0, params=params)
        if xeq is None:
            continue            # search did not converge

        # Only keep points that differ from everything found so far
        if not any(np.allclose(np.array(prev), xeq) for prev in found):
            found.append(xeq.tolist())

    return found
1026

1027

1028
def _make_points(pointdata, gridspec, gridtype):
9✔
1029
    # Check to see what type of data we got
1030
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
1031
        pointdata = np.atleast_2d(pointdata)
9✔
1032
        if pointdata.shape[1] == 2:
9✔
1033
            # Given a list of points => no action required
1034
            return pointdata, None
9✔
1035

1036
    # Utility function to parse (and check) input arguments
1037
    def _parse_args(defsize):
9✔
1038
        if gridspec is None:
9✔
1039
            return defsize
9✔
1040

1041
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
1042
             len(gridspec) != len(defsize):
1043
            raise ValueError("invalid grid specification")
9✔
1044

1045
        return gridspec
9✔
1046

1047
    # Generate points based on grid type
1048
    match gridtype:
9✔
1049
        case 'boxgrid' | None:
9✔
1050
            gridspec = _parse_args([6, 4])
9✔
1051
            points = boxgrid(
9✔
1052
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1053
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1054

1055
        case 'meshgrid':
9✔
1056
            gridspec = _parse_args([9, 6])
9✔
1057
            points = meshgrid(
9✔
1058
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1059
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1060

1061
        case 'circlegrid':
9✔
1062
            gridspec = _parse_args((0.5, 10))
9✔
1063
            if isinstance(pointdata, np.ndarray):
9✔
1064
                # Create circles around each point
1065
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
1066
            else:
1067
                # Create circle around center of the plot
1068
                points = circlegrid(
9✔
1069
                    np.array(
1070
                        [(pointdata[0] + pointdata[1]) / 2,
1071
                         (pointdata[0] + pointdata[1]) / 2]),
1072
                    gridspec[0], gridspec[1])
1073

1074
        case _:
9✔
1075
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
1076

1077
    return points, gridspec
9✔
1078

1079

1080
def _parse_arrow_keywords(kwargs):
    """Extract arrow settings from keyword arguments.

    The 'arrows', 'arrow_size' and 'arrow_style' entries are looked up via
    `config._get_param` and popped from `kwargs`.  Popping them means the
    caller's unrecognized-keyword check does not reject them, and they are
    not passed on to matplotlib.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments of the calling function (modified in place).

    Returns
    -------
    arrow_pos : list or 1D array
        Relative positions along each curve at which to place arrows.  An
        integer N gives N arrows spaced 1/N apart, starting at 0.5/N.  A
        list/array is sorted and used as given.  A falsy value or an empty
        list gives no arrows.
    arrow_style : `matplotlib.patches.ArrowStyle`
        User-supplied style, or a 'simple' arrow sized by `arrow_size`.

    Raises
    ------
    ValueError
        If `arrows` is not an integer, list or array.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # Parse the arrows keyword.  Lists/arrays are handled first, since the
    # truth value of a multi-element array is ambiguous.
    if isinstance(arrows, (list, np.ndarray)):
        arrow_pos = np.sort(np.atleast_1d(arrows)) if np.size(arrows) else []
    elif not arrows:
        arrow_pos = []
    elif isinstance(arrows, (int, np.integer)):
        N = int(arrows)
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
1108

1109

1110
# TODO: move to ctrlplot?
1111
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory through `X0` and clip it to the plot window.

    Parameters
    ----------
    sys, revsys : `NonlinearIOSystem`
        Forward-time system and its time-reversed counterpart.
    timepts : array_like
        Simulation time points.
    X0 : array_like
        Initial state.
    params : dict or None
        Parameters passed to the simulation.
    dir : {'forward', 'reverse', 'both'}
        Simulate forward in time, backward in time, or both.  For 'both',
        the reversed backward trace is joined to the forward trace, and the
        duplicated initial state is included only once.
    suppress_warnings : bool, optional
        If True, do not warn when a simulation reports failure.
    gridtype, gridspec : optional
        Accepted for call compatibility; not used.
    xlim, ylim : 2-element sequence, optional
        Plot window.  Points outside it are removed, except that the point
        on either side of an in-range point is kept so the curve reaches
        the boundary.  If a limit is None, that axis is not clipped.

    Returns
    -------
    traj : 2D array
        States along the trajectory, with shape (2, N).

    Raises
    ------
    ValueError
        If `dir` is not one of the recognized directions.

    """
    if dir not in ('forward', 'reverse', 'both'):
        raise ValueError(f"unknown trajectory direction '{dir}'")

    # Compute the forward trajectory
    if dir in ('forward', 'both'):
        fwdresp = input_output_response(
            sys, timepts, X0=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir in ('reverse', 'both'):
        revresp = input_output_response(
            revsys, timepts, X0=X0, params=params, ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {revresp.message}")

    # Create the trace to plot
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Reverse the backward trace and drop only its first column (X0),
        # which is also the first column of the forward trace
        traj = np.hstack([revresp.states[:, :0:-1], fwdresp.states])

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.ones(traj.shape[1], dtype=bool)
    if xlim is not None:
        inrange &= (traj[0] >= xlim[0]) & (traj[0] <= xlim[1])
    if ylim is not None:
        inrange &= (traj[1] >= ylim[0]) & (traj[1] <= ylim[1])
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
1144

1145

1146
def _make_timepts(timepts, i):
    """Return the time vector to use for the `i`-th trajectory.

    Parameters
    ----------
    timepts : None, scalar, or array_like
        None gives ``linspace(0, 1)``.  A scalar T (Python or numpy) gives
        ``linspace(0, T)``.  A 2D array_like supplies one time vector per
        row, and row `i` is returned.  A 1D array_like is returned as an
        array of time points.
    i : int
        Index of the trajectory (only used for 2D `timepts`).

    Returns
    -------
    1D array
        Time points for the simulation.

    """
    if timepts is None:
        return np.linspace(0, 1)

    # np.ndim handles Python and numpy scalars alike
    if np.ndim(timepts) == 0:
        return np.linspace(0, timepts)

    # Accept lists as well as arrays (asarray is a no-op for ndarrays)
    timepts = np.asarray(timepts)
    if timepts.ndim == 2:
        return timepts[i]
    return timepts
1154

1155

1156
#
1157
# Legacy phase plot function
1158
#
1159
# Author: Richard Murray
1160
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1161
# a version by Kristi Morgansen
1162
#
1163
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):
    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use `phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by the `phase_plane_plot` function.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - stream lines
          with N arrows at linearly spaced time indices
      phase_plot(func, X0=[...], logtime=(N, lambda), ...) - stream lines
          with N arrows spaced 1/lambda apart in time

    Parameters
    ----------
    odefun : callable
        Computes the time derivative of the state (compatible with
        `scipy.integrate.odeint`).  It is called as
        ``odefun(x, t, *params)``, or as ``odefun(t, x, *params)`` if
        `tfirst` is True, where x is a 2-element state, and must return a
        2-element derivative.
    X, Y : 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.  The vector
        field is only plotted if none of `lingrid`, `logtime` or `timepts`
        is given.
    scale : float or None, optional
        Scale size of arrows; default = 1.  If 0, no arrows are drawn.  If
        negative, its absolute value is used and dots are added at the
        base of the streamline arrows.  If None, streamline arrows only
        appear inside the current axis limits.
    X0 : ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T : array_like or number, optional
        Simulation time for the streamlines.  A number T simulates from 0
        to T using 100 time points.  Otherwise, T is used directly as the
        array of time points passed to `odeint` for every initial
        condition.  Default value = 50.
    lingrid : int, optional
        Draw this many arrows along each streamline, at linearly spaced
        indices into the simulation time vector.
    lintime : optional
        Accepted for backward compatibility; not used.
    logtime : tuple (int, float), optional
        Tuple (N, lambda): draw N arrows along each streamline.  Arrow j is
        placed at the last simulation time before ``j / lambda``, or before
        ``(j + 1) / lambda`` if `scale` is None.
    timepts : array_like, optional
        Draw arrows at the given list times [t1, t2, ...]: arrow j is
        placed at the last simulation time before ``timepts[j]``.
    parms : tuple, optional
        Deprecated alias for `params`.
    params : tuple, optional
        List of parameters to pass to vector field: ``func(x, t, *params)``.
    tfirst : bool, optional
        If True, call `func` with signature ``func(t, x, ...)``.
    verbose : bool, optional
        If True, print which arrow placement mode is being used.

    See Also
    --------
    box_grid

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            "keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument("duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if verbose:
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if verbose:
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    if not autoFlag and not logtimeFlag and not timeptsFlag and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i, j], x2[i, j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i, j], x2[i, j]], 0, *params))

        # Plot the quiver plot (components are dx[..., 0] and dx[..., 1])
        if scale is None:
            plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
        elif scale != 0:
            plt.quiver(x1, x2, dx[:, :, 0]*np.abs(scale),
                       dx[:, :, 1]*np.abs(scale), angles='xy')

        plt.xlabel('x1')
        plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array (one row per condition)
    X0 = np.atleast_2d(np.array(X0))
    nr = X0.shape[0]

    # Arrow bases and directions along the streamlines.  Entries that are
    # never filled (skipped arrows) stay NaN, so they are neither drawn by
    # quiver (which masks invalid values) nor by plot.
    x1 = np.full((nr, Narrows), np.nan)
    x2 = np.full((nr, Narrows), np.nan)
    dx = np.full((nr, Narrows, 2), np.nan)

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed (arrays are needed for the comparisons
    # used to place arrows below)
    if np.ndim(T) == 0:
        TSPAN = np.linspace(0, T, 100)
    else:
        TSPAN = np.asarray(T)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin, xmax = alim[0], alim[1]
        ymin, ymax = alim[2], alim[3]
    else:
        # Use the maximum extent of all trajectories
        xmin, xmax = np.min(X0[:, 0]), np.max(X0[:, 0])
        ymin, ymax = np.min(X0[:, 1]), np.max(X0[:, 1])

    streamline_arrows = autoFlag or logtimeFlag or timeptsFlag

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = TSPAN

        plt.plot(state[:, 0], state[:, 1])

        # Plot arrows if quiver parameters were 'auto'
        if not streamline_arrows:
            continue

        for j in range(Narrows):
            # Figure out starting index; headless arrows start at 0
            k = -1 if scale is None else 0

            # Figure out what time index to use for the next point
            if autoFlag:
                # Use a linear scaling based on ODE time vector (an
                # integer index is required for indexing `state`)
                tind = int(np.floor((len(time)/Narrows) * (j-k))) + k
            elif logtimeFlag:
                # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                tarr = _find(time < (j-k) / timefactor)
                tind = tarr[-1] if len(tarr) else 0
            else:
                # MATLAB: tind = find(time < Y[j], 1, 'last')
                tarr = _find(time < timepts[j])
                tind = tarr[-1] if len(tarr) else 0

            # For tailless arrows, skip the first point
            if tind == 0 and scale is None:
                continue

            # Figure out the arrow at this point on the curve
            x1[i, j] = state[tind, 0]
            x2[i, j] = state[tind, 1]

            # Skip arrows outside of initial condition box
            if (scale is not None or
                    (xmin <= x1[i, j] <= xmax and ymin <= x2[i, j] <= ymax)):
                if tfirst:
                    v = odefun(0, [x1[i, j], x2[i, j]], *params)
                else:
                    v = odefun([x1[i, j], x2[i, j]], 0, *params)
                dx[i, j, 0] = v[0]
                dx[i, j, 1] = v[1]
            else:
                dx[i, j, 0] = 0
                dx[i, j, 1] = 0

    # Plot arrows on the streamlines (only computed in the arrow modes)
    if streamline_arrows and Narrows > 0:
        if scale is None:
            # Use a tailless arrow
            plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
        elif scale != 0:
            plt.quiver(x1, x2, dx[:, :, 0]*abs(scale),
                       dx[:, :, 1]*abs(scale), angles='xy')

        # Negative scale: add dots at the base of each arrow (guard against
        # scale=None, which cannot be compared with 0)
        if scale is not None and scale < 0:
            plt.plot(x1, x2, 'b.')

1428

1429
# Utility function for generating initial conditions around a box
1430
def box_grid(xlimp, ylimp):
    """Generate list of points on edge of box.

    .. deprecated:: 0.10.0
        Use `phaseplot.boxgrid` instead.

    ``box_grid([xmin, xmax, xnum], [ymin, ymax, ynum])`` returns the points
    on the boundary of the box with corners [xmin, ymin] and [xmax, ymax].
    The points come from a uniform grid with xnum points in x and ynum
    points in y, and are ordered as by `boxgrid`.

    """
    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    xvals = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    yvals = np.linspace(ylimp[0], ylimp[1], ylimp[2])
    return boxgrid(xvals, yvals)
1450

1451

1452
# TODO: rename to something more useful (or remove??)
1453
def _find(condition):
    """Return the indices at which the flattened `condition` is true.

    Private replacement for the deprecated `matplotlib.mlab.find`.

    """
    return np.flatnonzero(condition)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc