• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13092133168

01 Feb 2025 08:24PM UTC coverage: 94.63% (-0.03%) from 94.659%
13092133168

Pull #1112

github

web-flow
Merge c0e4cb425 into 2cb0520b6
Pull Request #1112: Use matplotlibs streamplot function for phase_plane_plot

9709 of 10260 relevant lines covered (94.63%)

8.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.95
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Author: Richard M. Murray
4
# Date:   23 Mar 2024 (legacy version information below)
5
#
6
# TODO
7
# * Allow multiple timepoints (and change timespec name to T?)
8
# * Update linestyles (color -> linestyle?)
9
# * Check for keyword compatibility with other plot routines
10
# * Set up configuration parameters (nyquist --> phaseplot)
11

12
"""Module for generating 2D phase plane plots.
13

14
The :mod:`control.phaseplot` module contains functions for generating 2D
15
phase plots. The base function for creating phase plane portraits is
16
:func:`~control.phase_plane_plot`, which generates a phase plane portrait
17
for a 2 state I/O system (with no inputs).  In addition, several other
18
functions are available to create customized phase plane plots:
19

20
* boxgrid: Generate a list of points along the edge of a box
21
* circlegrid: Generate list of points around a circle
22
* equilpoints: Plot equilibrium points in the phase plane
23
* meshgrid: Generate a list of points forming a mesh
24
* separatrices: Plot separatrices in the phase plane
25
* streamlines: Plot stream lines in the phase plane
26
* vectorfield: Plot a vector field in the phase plane
27
* streamplot: Plot streamlines using matplotlib's streamplot function
28

29
"""
30

31
import math
9✔
32
import warnings
9✔
33

34
import matplotlib as mpl
9✔
35
import matplotlib.pyplot as plt
9✔
36
import numpy as np
9✔
37
from scipy.integrate import odeint
9✔
38

39
from . import config
9✔
40
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
41
    _process_ax_keyword, _update_plot_title
42
from .exception import ControlNotImplemented
9✔
43
from .nlsys import NonlinearIOSystem, find_operating_point, \
9✔
44
    input_output_response
45

46
__all__ = [
    'phase_plane_plot',
    'phase_plot',
    'box_grid',
]

# Default values for the 'phaseplot.*' configuration parameters
_phaseplot_defaults = {
    # Number of arrows drawn along each curve
    'phaseplot.arrows': 2,
    # Size of each arrow, in pixels
    'phaseplot.arrow_size': 8,
    # Initial radius used when starting separatrices
    'phaseplot.separatrices_radius': 0.1,
}
54

55
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
        plot_separatrices=True, ax=None, suppress_warnings=False, title=None,
        plot_streamplot=False, **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  Defaults to
        [-1, 1, -1, 1].
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use `plot_<fcn>={'color': c}`
        to set the color in one element of the phase plot).
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : :class:`ControlPlot` object
        Object containing the data that were plotted:

          * cplt.lines: array of list of :class:`matplotlib.artist.Artist`
            objects:

              - lines[0] = list of Line2D objects (streamlines, separatrices).
              - lines[1] = Quiver object (vector field arrows).
              - lines[2] = list of Line2D objects (equilibrium points).

            The output of `plot_streamplot` is drawn on the axes but is not
            stored in `cplt.lines`.

          * cplt.axes: 2D array of :class:`matplotlib.axes.Axes` for the plot.

          * cplt.figure: :class:`matplotlib.figure.Figure` containing the plot.

        See :class:`ControlPlot` for more detailed information.

    Raises
    ------
    TypeError
        If any keyword argument is not consumed by one of the plotting
        elements.

    Other parameters
    ----------------
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the ``timedata`` argument.
    plot_streamlines : bool or dict, optional
        If `True` (default) then plot streamlines based on the pointdata
        and gridtype.  If set to a dict, pass on the key-value pairs in
        the dict as keywords to :func:`~control.phaseplot.streamlines`.
    plot_vectorfield : bool or dict, optional
        If `True` then plot the vector field based on the pointdata and
        gridtype (default is `False`).  If set to a dict, pass on the
        key-value pairs in the dict as keywords to
        :func:`~control.phaseplot.vectorfield`.
    plot_streamplot : bool or dict, optional
        If `True` then use matplotlib's streamplot function to plot the
        streamlines (default is `False`).  If set to a dict, pass on the
        key-value pairs in the dict as keywords to
        :func:`~control.phaseplot.streamplot`.  This is independent of
        `plot_streamlines`; set `plot_streamlines=False` to show only the
        streamplot.
    plot_equilpoints : bool or dict, optional
        If `True` (default) then plot equilibrium points based in the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to :func:`~control.phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If `True` (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to :func:`~control.phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        If set to `True`, suppress warning messages in generating trajectories.
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    """
    # Process arguments
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments;
    # each plotting element below removes the keywords it handles
    initial_kwargs = dict(kwargs)
    passed_kwargs = False

    # Utility function to create keyword arguments.  Precedence (lowest to
    # highest): global kwargs, fixed keyword values, per-element dict.
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Create list for storing outputs: [lines, quiver, equilibrium points]
    out = np.array([[], None, None], dtype=object)

    # Plot out the main elements
    if plot_streamlines:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False,
            suppress_warnings=suppress_warnings, **kwargs_local)

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # A 'circlegrid' gridspec is (radius, number of points), not a grid
    # size, so it does not apply to the remaining elements
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_separatrices:
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_streamplot:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
        # The StreamplotSet is drawn on `ax` but not stored in `out`
        streamplot(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by streamplot
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError("unrecognized keywords: ", str(initial_kwargs))

    # Only decorate the figure (title, labels, layout) if we created it
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
254

255

256
def vectorfield(
        sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate points.
        Points are always generated as a 'meshgrid'.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : matplotlib.quiver.Quiver
        Quiver object containing the vector field arrows.

    Raises
    ------
    TypeError
        If unrecognized keywords are given (and `_check_kwargs` is True).

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions; it
        is not used by this function.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (called for its effect on `ax`)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use (removes 'color' from kwargs, if present)
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Generate phase plane (quiver) data: columns are x, y, dx/dt, dy/dt,
    # evaluated at t = 0 with all inputs set to zero
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, np.zeros(sys.ninputs))

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color)

    return out
347

348

349
def streamplot(
        sys, pointdata, gridspec=None, ax=None, vary_color=False,
        vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot streamlines in the phase plane using matplotlib's streamplot.

    This function evaluates the vector field of a two-dimensional state
    space system on a regular grid and passes it to
    :meth:`matplotlib.axes.Axes.streamplot`, which draws the streamlines.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points from which to make the streamplot.  In the latter
        case, the points must lie on a regular grid ordered like the
        output of `meshgrid`: row by row, with x varying fastest.
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate points.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the streamlines in the given color (ignored if `vary_color`
        is `True`).
    vary_color : bool, optional
        If set to `True`, vary the color of the streamlines based on the
        magnitude of the vector field.
    vary_linewidth : bool, optional
        If set to `True`, vary the linewidth of the streamlines based on the
        magnitude of the vector field, between 0.25 and 2 times the default
        line width.
    cmap : str or Colormap, optional
        Colormap to use for varying the color of the streamlines.
    norm : Normalize, optional
        An instance of Normalize to use for scaling the colormap and
        linewidths.  Defaults to `matplotlib.colors.Normalize()`, which
        autoscales to the vector field magnitudes.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : matplotlib.streamplot.StreamplotSet
        Container for the lines and arrows drawn by the streamplot.

    Raises
    ------
    ValueError
        If `pointdata` is an array of points that cannot be arranged into
        a regular grid.
    TypeError
        If unrecognized keywords are given (and `_check_kwargs` is True).

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions; it
        is not used by this function.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the streamplot field
    points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')

    # If no grid size is available (raw points were given), recover it.
    # Points are assumed to be ordered row by row with x varying fastest,
    # so every decrease in x marks the start of a new row (new y value).
    if gridspec is None:
        nrows = int(np.sum(np.diff(points[:, 0]) < 0)) + 1
        ncols = points.shape[0] // nrows
        if nrows * ncols != points.shape[0]:
            raise ValueError("Could not recover grid from points.")
        # gridspec is ordered as [size along x, size along y]
        gridspec = [ncols, nrows]

    # streamplot wants arrays with one row per y value: shape (ny, nx)
    grid_arr_shape = gridspec[::-1]
    xs = points[:, 0].reshape(grid_arr_shape)
    ys = points[:, 1].reshape(grid_arr_shape)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (called for its effect on `ax`)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use (removes 'color' from kwargs, if present)
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Evaluate the vector field at t = 0 with all inputs set to zero
    sys._update_params(params)
    us_flat, vs_flat = np.transpose(
        [sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
    us = us_flat.reshape(grid_arr_shape)
    vs = vs_flat.reshape(grid_arr_shape)

    # Magnitudes drive the optional color and linewidth variation
    magnitudes = np.linalg.norm([us, vs], axis=0)
    norm = norm or mpl.colors.Normalize()
    normalized = norm(magnitudes)
    cmap = plt.get_cmap(cmap)

    with plt.rc_context(rcParams):
        default_lw = plt.rcParams['lines.linewidth']
        min_lw, max_lw = 0.25 * default_lw, 2 * default_lw
        linewidths = (normalized * (max_lw - min_lw) + min_lw
                      if vary_linewidth else None)
        color = magnitudes if vary_color else color

        out = ax.streamplot(
            xs, ys, us, vs, color=color, linewidth=linewidths, cmap=cmap,
            norm=norm)

    return out
458

459
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):
    """Plot stream lines in the phase plane.

    This function simulates a two-dimensional state space system from a
    set of initial conditions and plots the resulting trajectories as
    stream lines, decorated with arrows.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving the initial conditions for the stream lines.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' generates a mesh of initial conditions within the
        specified boundaries, 'boxgrid' generates initial conditions along
        the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the ``timedata`` argument.
        If not given, 'both' is used when `gridtype` is 'meshgrid' and
        'forward' otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects
        One line per initial condition that produced a trajectory.

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        If set to `True`, suppress warning messages in generating trajectories.

    """
    # Pull the plot settings and system parameters out of the keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
    params = kwargs.pop('params', None)

    # Turn a plain callable into an I/O system if necessary
    sys = _create_system(sys, params)

    # Arrow placement and style (consumes the arrow keywords)
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Initial conditions for the stream lines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        # Default: integrate both ways for mesh grids, forward otherwise
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Axes to draw on, plus the plot limits used to clip trajectories
    ax = plt.gca() if ax is None else ax
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Line color (consumes the 'color' keyword, if present)
    color = _get_color(kwargs, ax=ax)

    # Anything still left in kwargs was not understood
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Backward integration uses a copy of the system with negated dynamics
    revsys = None
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)

    lines = []
    for idx, x0 in enumerate(points):
        traj = _create_trajectory(
            sys, revsys, _make_timepts(timedata, idx), x0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Nothing to draw unless the trajectory has at least two points
        if traj.shape[1] <= 1:
            continue

        with plt.rc_context(rcParams):
            lines += ax.plot(traj[0], traj[1], color=color)

            # Decorate the new line with direction arrows
            _add_arrows_to_line2D(
                ax, lines[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return lines
578

579

580
def equilpoints(
        sys, pointdata, gridspec=None, color='k', ax=None, _check_kwargs=True,
        **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot (used to set the axis limits and to
        generate the search points).
    gridspec : list, optional
        Size of the grid in the x and y axes of the mesh of points used to
        search for equilibrium points.  Defaults to [5, 5].
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the equilibrium points in the given color (default 'k').
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects
        One marker line per equilibrium point found.

    Raises
    ------
    TypeError
        If unrecognized keywords are given (and `_check_kwargs` is True).

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (called for its effect on `ax`)
    _set_axis_limits(ax, pointdata)

    # Determine the mesh of points used to search for equilibria
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points
    out = []
    for xeq in equilpts:
        with plt.rc_context(rcParams):
            out += ax.plot(xeq[0], xeq[1], marker='o', color=color)
    return out
662

663

664
def separatrices(
9✔
665
        sys, pointdata, timedata=None, gridspec=None, ax=None,
666
        _check_kwargs=True, suppress_warnings=False, **kwargs):
667
    """Plot separatrices in the phase plane.
668

669
    This function plots separatrices for a two-dimensional state space
670
    system.
671

672
    Parameters
673
    ----------
674
    sys : NonlinearIOSystem or callable(t, x, ...)
675
        I/O system or function used to generate phase plane data. If a
676
        function is given, the remaining arguments are drawn from the
677
        `params` keyword.
678
    pointdata : list or 2D array
679
        List of the form [xmin, xmax, ymin, ymax] describing the
680
        boundaries of the phase plot or an array of shape (N, 2)
681
        giving points of at which to plot the vector field.
682
    timedata : int or list of int
683
        Time to simulate each streamline.  If a list is given, a different
684
        time can be used for each initial condition in `pointdata`.
685
    gridtype : str, optional
686
        The type of grid to use for generating initial conditions:
687
        'meshgrid' (default) generates a mesh of initial conditions within
688
        the specified boundaries, 'boxgrid' generates initial conditions
689
        along the edges of the boundary, 'circlegrid' generates a circle of
690
        initial conditions around each point in point data.
691
    gridspec : list, optional
692
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
693
        size of the grid in the x and y axes on which to generate points.
694
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
695
        specifying the radius and number of points around each point in the
696
        `pointdata` array.
697
    params : dict or list, optional
698
        Parameters to pass to system. For an I/O system, `params` should be
699
        a dict of parameters and values. For a callable, `params` should be
700
        dict with key 'args' and value given by a tuple (passed to callable).
701
    color : matplotlib color spec, optional
702
        Plot the separatrics in the given color.  If a single color
703
        specification is given, this is used for both stable and unstable
704
        separatrices.  If a tuple is given, the first element is used as
705
        the color specification for stable separatrices and the second
706
        elmeent for unstable separatrices.
707
    ax : matplotlib.axes.Axes
708
        Use the given axes for the plot, otherwise use the current axes.
709

710
    Returns
711
    -------
712
    out : list of Line2D objects
713

714
    Other parameters
715
    ----------------
716
    rcParams : dict
717
        Override the default parameters used for generating plots.
718
        Default is set by config.default['ctrlplot.rcParams'].
719
    suppress_warnings : bool, optional
720
        If set to `True`, suppress warning messages in generating trajectories.
721

722
    """
723
    # Process keywords
724
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
725

726
    # Get system parameters
727
    params = kwargs.pop('params', None)
9✔
728

729
    # Create system from callable, if needed
730
    sys = _create_system(sys, params)
9✔
731

732
    # Parse the arrows keyword
733
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
734

735
    # Determine the initial states to use in searching for equilibrium points
736
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
737
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
738

739
    # Find the equilibrium points
740
    equilpts = _find_equilpts(sys, points, params=params)
9✔
741
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
742

743
    # Create axis if needed
744
    if ax is None:
9✔
745
        ax = plt.gca()
9✔
746

747
    # Set the axis limits
748
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
749

750
    # Figure out the color to use for stable, unstable subspaces
751
    color = _get_color(kwargs)
9✔
752
    match color:
9✔
753
        case None:
9✔
754
            stable_color = 'r'
9✔
755
            unstable_color = 'b'
9✔
756
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
757
            pass
9✔
758
        case single_color:
9✔
759
            stable_color = unstable_color = color
9✔
760

761
    # Make sure all keyword arguments were processed
762
    if _check_kwargs and kwargs:
9✔
763
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
764

765
    # Create a "reverse time" system to use for simulation
766
    revsys = NonlinearIOSystem(
9✔
767
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
768
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
769
        outputs=sys.noutputs, params=sys.params)
770

771
    # Plot separatrices by flowing backwards in time along eigenspaces
772
    out = []
9✔
773
    for i, xeq in enumerate(equilpts):
9✔
774
        # Plot the equilibrium points
775
        with plt.rc_context(rcParams):
9✔
776
            out += ax.plot(xeq[0], xeq[1], marker='o', color='k')
9✔
777

778
        # Figure out the linearization and eigenvectors
779
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
780

781
        # See if we have real eigenvalues (=> evecs are meaningful)
782
        if evals[0].imag > 0:
9✔
783
            continue
9✔
784

785
        # Create default list of time points
786
        if timedata is not None:
9✔
787
            timepts = _make_timepts(timedata, i)
9✔
788

789
        # Generate the traces
790
        for j, dir in enumerate(evecs.T):
9✔
791
            # Figure out time vector if not yet computed
792
            if timedata is None:
9✔
793
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
794
                timepts = np.linspace(0, timescale)
9✔
795

796
            # Run the trajectory starting in eigenvector directions
797
            for eps in [-radius, radius]:
9✔
798
                x0 = xeq + dir * eps
9✔
799
                if evals[j].real < 0:
9✔
800
                    traj = _create_trajectory(
9✔
801
                        sys, revsys, timepts, x0, params, 'reverse',
802
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
803
                        suppress_warnings=suppress_warnings)
804
                    color = stable_color
9✔
805
                    linestyle = '--'
9✔
806
                elif evals[j].real > 0:
9✔
807
                    traj = _create_trajectory(
9✔
808
                        sys, revsys, timepts, x0, params, 'forward',
809
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
810
                        suppress_warnings=suppress_warnings)
811
                    color = unstable_color
9✔
812
                    linestyle = '-'
9✔
813

814
                # Plot the trajectory (if there is one)
815
                if traj.shape[1] > 1:
9✔
816
                    with plt.rc_context(rcParams):
9✔
817
                        out += ax.plot(
9✔
818
                            traj[0], traj[1], color=color, linestyle=linestyle)
819

820
                    # Add arrows to the lines at specified intervals
821
                    with plt.rc_context(rcParams):
9✔
822
                        _add_arrows_to_line2D(
9✔
823
                            ax, out[-1], arrow_pos, arrowstyle=arrow_style,
824
                            dir=1)
825
    return out
9✔
826

827

828
#
829
# User accessible utility functions
830
#
831

832
# Utility function to generate boxgrid (in the form needed here)
833
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    points = boxgrid(xvals, yvals) returns the points lying on the
    boundary of the box whose sides are sampled at the given x and y
    values.  For increasing `xvals` and `yvals`, the boundary is traversed
    counter-clockwise starting at the lower left corner: bottom edge, right
    edge, top edge, then left edge.  Each corner appears exactly once.

    Parameters
    ----------
    xvals, yvals : 1D array-like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge.

    """
    edge_points = []
    # bottom edge, left to right (last x omitted: it starts the right edge)
    edge_points.extend((x, yvals[0]) for x in xvals[:-1])
    # right edge, bottom to top (last y omitted: it starts the top edge)
    edge_points.extend((xvals[-1], y) for y in yvals[:-1])
    # top edge, right to left (first x omitted: it starts the left edge)
    edge_points.extend((x, yvals[-1]) for x in xvals[:0:-1])
    # left edge, top to bottom (first y omitted: already on the bottom edge)
    edge_points.extend((xvals[0], y) for y in yvals[:0:-1])
    return np.array(edge_points)
858

859

860
# Utility function to generate meshgrid (in the form needed here)
861
# TODO: add examples of using grid functions directly
862
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    points = meshgrid(xvals, yvals) returns every (x, y) pair that can be
    formed from the given x and y values, i.e. the nodes of the
    rectangular grid with those coordinates.  The x coordinate varies
    fastest from one row to the next.

    Parameters
    ----------
    xvals, yvals : 1D array-like
        Coordinates of the grid lines in the x and y directions.

    Returns
    -------
    grid : 2D array
        Floating point array of shape (n * m, 2), where n and m are the
        number of x and y values, with one grid point per row.

    """
    xmesh, ymesh = np.meshgrid(xvals, yvals)

    # Allocate a float array and fill one coordinate column at a time
    grid = np.zeros((xmesh.size, 2))
    for column, mesh in enumerate((xmesh, ymesh)):
        grid[:, column] = mesh.reshape(-1)
    return grid
886

887

888
# Utility function to generate circular grid
889
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    points = circlegrid(centers, radius, num) places `num` evenly spaced
    points on a circle of the given radius around each center.  The points
    on each circle start at angle zero (the +x direction) and proceed
    counter-clockwise.

    Parameters
    ----------
    centers : 2D array-like
        Array of points with shape (p, 2) defining centers of the circles
        (a single point of shape (2,) is also accepted).
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid : 2D array
        Array of points with shape (p * num, 2); rows k*num through
        (k+1)*num - 1 lie on the circle around the k-th center.

    """
    centers = np.atleast_2d(np.array(centers))
    grid = np.zeros((centers.shape[0] * num, 2))

    # The offsets from the center are the same for every circle
    angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
    offsets = np.array(
        [[radius * math.cos(angle), radius * math.sin(angle)]
         for angle in angles])

    for k, center in enumerate(centers):
        start = k * num
        grid[start:start + num, :] = center + offsets
    return grid
917

918
#
919
# Internal utility functions
920
#
921

922
# Create a system from a callable
923
def _create_system(sys, params):
    """Return a planar NonlinearIOSystem corresponding to `sys`.

    If `sys` is already a NonlinearIOSystem, it is returned unchanged after
    checking that it has exactly two states.  Otherwise `sys` is treated as
    a callable ``sys(t, x, *args)`` returning the state derivative.  It is
    wrapped in a NonlinearIOSystem with two states, no inputs and no
    outputs, and the extra arguments are taken from ``params['args']``.

    Raises
    ------
    ValueError
        If an I/O system does not have two states, or if `params` is given
        for a callable but is not a dict with an 'args' entry.
    """
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates != 2:
            raise ValueError("system must be planar")
        return sys

    # Make sure that if params is present, it is a dict with an 'args' key.
    # An empty args tuple is valid, so test the value against None rather
    # than its truthiness; a non-dict gets a clear error, not AttributeError.
    if params and (
            not isinstance(params, dict) or params.get('args') is None):
        raise ValueError("params must be dict with key 'args'")

    def _update(t, x, u, params):
        # Forward the extra positional arguments stored under 'args'
        return sys(t, x, *params.get('args', ()))

    def _output(t, x, u, params):
        # The wrapped system has no outputs
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
937

938
# Set axis limits for the plot
939
def _set_axis_limits(ax, pointdata):
    """Extend the limits of `ax` to cover the region given by `pointdata`.

    If the axes already contain lines, their current limits are widened.
    Otherwise the limits come from `pointdata` alone.  `pointdata` is
    either a list [xmin, xmax, ymin, ymax] or an array of (x, y) points;
    other forms are not used to update the limits.

    Returns
    -------
    xlim, ylim : sequence of two floats
        The new x and y limits (also applied to `ax`).
    maxlim : float
        The larger of the x and y extents.
    """
    if not ax.lines:
        # Empty axes => limits are determined entirely by pointdata
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]
    else:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()

    def _widen(limits, values):
        # Grow [lo, hi] so that it covers all of `values`
        lower = min(limits[0], np.min(values))
        upper = max(limits[1], np.max(values))
        return [lower, upper]

    if isinstance(pointdata, list) and len(pointdata) == 4:
        # Box given as [xmin, xmax, ymin, ymax]
        xlim = _widen(xlim, [pointdata[0], pointdata[1]])
        ylim = _widen(ylim, [pointdata[2], pointdata[3]])
    elif isinstance(pointdata, np.ndarray):
        # Collection of (x, y) points
        pts = np.atleast_2d(pointdata)
        xs, ys = pts[:, 0], pts[:, 1]
        xlim = _widen(xlim, [np.min(xs), np.max(xs)])
        ylim = _widen(ylim, [np.min(ys), np.max(ys)])

    # Largest dimension on the plot (used to size separatrix simulations)
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])

    for axis_name in ('x', 'y'):
        ax.autoscale(enable=True, axis=axis_name, tight=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return xlim, ylim, maxlim
973

974

975
# Find equilibrium points
976
def _find_equilpts(sys, points, params=None):
    """Find the distinct equilibrium points of `sys` near `points`.

    An operating-point search with zero input is started from each entry
    of `points`.  Searches that return no state are ignored, and any result
    that is ``np.allclose`` to an equilibrium already found is discarded.

    Returns
    -------
    equilpts : list of list
        Equilibrium states (as lists), in the order they were found.
    """
    found = []
    for start in points:
        # Look for an equilibrium point near this starting state
        xeq, _ = find_operating_point(sys, start, 0, params=params)
        if xeq is None:
            continue            # didn't find anything

        # Keep the point only if it does not duplicate an earlier one
        if not any(np.allclose(np.array(prev), xeq) for prev in found):
            found.append(xeq.tolist())
    return found
997

998

999
def _make_points(pointdata, gridspec, gridtype):
9✔
1000
    # Check to see what type of data we got
1001
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
1002
        pointdata = np.atleast_2d(pointdata)
9✔
1003
        if pointdata.shape[1] == 2:
9✔
1004
            # Given a list of points => no action required
1005
            return pointdata, None
9✔
1006

1007
    # Utility function to parse (and check) input arguments
1008
    def _parse_args(defsize):
9✔
1009
        if gridspec is None:
9✔
1010
            return defsize
9✔
1011

1012
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
1013
             len(gridspec) != len(defsize):
1014
            raise ValueError("invalid grid specification")
9✔
1015

1016
        return gridspec
9✔
1017

1018
    # Generate points based on grid type
1019
    match gridtype:
9✔
1020
        case 'boxgrid' | None:
9✔
1021
            gridspec = _parse_args([6, 4])
9✔
1022
            points = boxgrid(
9✔
1023
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1024
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1025

1026
        case 'meshgrid':
9✔
1027
            gridspec = _parse_args([9, 6])
9✔
1028
            points = meshgrid(
9✔
1029
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1030
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1031

1032
        case 'circlegrid':
9✔
1033
            gridspec = _parse_args((0.5, 10))
9✔
1034
            if isinstance(pointdata, np.ndarray):
9✔
1035
                # Create circles around each point
1036
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
1037
            else:
1038
                # Create circle around center of the plot
1039
                points = circlegrid(
9✔
1040
                    np.array(
1041
                        [(pointdata[0] + pointdata[1]) / 2,
1042
                         (pointdata[0] + pointdata[1]) / 2]),
1043
                    gridspec[0], gridspec[1])
1044

1045
        case _:
9✔
1046
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
1047

1048
    return points, gridspec
9✔
1049

1050

1051
def _parse_arrow_keywords(kwargs):
    """Extract arrow placement and style settings from keyword arguments.

    The 'arrows', 'arrow_size' and 'arrow_style' keywords are removed from
    `kwargs`, so that the caller's check for unrecognized keywords does not
    trip on them.  Values not given fall back to the 'phaseplot'
    configuration defaults.

    Returns
    -------
    arrow_pos : list or 1D array
        Arrow positions.  For an integer `arrows` = N these are the
        midpoints 0.5/N, 1.5/N, ... of N equal subdivisions of [0, 1].
        For a list, tuple or array they are the given values, sorted.
        The result is empty if arrows are disabled.
    arrow_style : matplotlib.patches.ArrowStyle
        Style to use for the arrows.

    Raises
    ------
    ValueError
        If `arrows` is not an integer, list, tuple or array.
    """
    # Get values for params (and pop from kwargs so that the caller's
    # unrecognized-keyword check does not see them)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # Parse the arrows keyword.  Explicit positions are checked first
    # because the truth value of a multi-element array is ambiguous.
    if isinstance(arrows, (list, tuple, np.ndarray)):
        arrow_pos = np.sort(np.atleast_1d(arrows)) if np.size(arrows) else []
    elif not arrows:
        arrow_pos = []
    elif isinstance(arrows, (int, np.integer)):
        N = arrows
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
1079

1080

1081
# TODO: move to ctrlplot?
1082
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory of `sys` passing through `X0`.

    Parameters
    ----------
    sys, revsys : NonlinearIOSystem
        System and its time-reversed counterpart.
    timepts : 1D array
        Time points for the simulation(s).
    X0 : array-like
        State through which the trajectory passes.
    params : dict or None
        Parameters passed to the simulations.
    dir : {'forward', 'reverse', 'both'}
        Direction(s) in time to simulate.
    suppress_warnings : bool, optional
        If True, do not warn when a simulation reports failure.
    gridtype, gridspec : optional
        Not used; accepted for call compatibility.
    xlim, ylim : sequence of two floats, optional
        Window limits.  Points outside the window are removed, except
        those adjacent to a point inside it.  If either is None, the
        trajectory is not clipped.

    Returns
    -------
    traj : 2D array
        States along rows and time (in forward order) along columns.

    Raises
    ------
    ValueError
        If `dir` is not one of the recognized values.
    """
    if dir not in ('forward', 'reverse', 'both'):
        raise ValueError(f"unknown trajectory direction '{dir}'")

    # Compute the forward trajectory
    if dir in ('forward', 'both'):
        fwdresp = input_output_response(
            sys, timepts, X0=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir in ('reverse', 'both'):
        revresp = input_output_response(
            revsys, timepts, X0=X0, params=params, ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {revresp.message}")

    # Create the trace to plot, ordered in forward time
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Reverse the backward-time samples, dropping only sample 0 (which
        # is X0, also the first sample of the forward trajectory)
        traj = np.hstack([revresp.states[:, :0:-1], fwdresp.states])

    if xlim is None or ylim is None:
        return traj

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.asarray(
        (traj[0] >= xlim[0]) & (traj[0] <= xlim[1]) &
        (traj[1] >= ylim[0]) & (traj[1] <= ylim[1]))
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
1115

1116

1117
def _make_timepts(timepts, i):
    """Return the simulation time points to use for initial condition `i`.

    Parameters
    ----------
    timepts : None, scalar, or array-like
        None gives ``np.linspace(0, 1)``.  A scalar T (Python or numpy)
        gives ``np.linspace(0, T)``.  A 2D array gives its row `i`.  Any
        other array-like is returned as an array unchanged.
    i : int
        Index of the initial condition; only used for 2D `timepts`.
    """
    if timepts is None:
        return np.linspace(0, 1)

    # Any 0-d value (int, float, or numpy scalar) is a final time
    if np.ndim(timepts) == 0:
        return np.linspace(0, timepts)

    timepts = np.asarray(timepts)
    if timepts.ndim == 2:
        return timepts[i]
    return timepts
1125

1126

1127
#
1128
# Legacy phase plot function
1129
#
1130
# Author: Richard Murray
1131
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1132
# a version by Kristi Morgansen
1133
#
1134
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):

    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use :func:`phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by the :func:`~control.phase_plane_plot`
    function.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - stream lines
          with N arrows each
      phase_plot(func, X0=[...], logtime=(N, lambda), ...) - stream lines
          with arrows spaced in time
      phase_plot(func, X0=[...], timepts=[...], ...) - stream lines with
          arrows at given times

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Function computing the time derivative of the state, in the form
        used by :func:`scipy.integrate.odeint`.  It is called as
        ``odefun(x, t, *params)`` (or ``odefun(t, x, *params)`` if
        `tfirst` is True), where x is a state of dimension 2, and should
        return the derivative dx/dt of dimension 2.
    X, Y: 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.
    scale: float, optional
        Scale size of arrows; default = 1.  If 0, no arrows are drawn.  If
        negative, dots are also drawn at the base of the streamline
        arrows.  If None, unscaled arrows are drawn and the current axis
        limits determine which streamline arrows are shown.
    X0: ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T: array-like or number, optional
        Length of time to run simulations that generate streamlines.  A
        single number T gives 100 time points evenly spaced in [0, T].
        Otherwise it is used directly as the array of simulation times.
        Default value = 50.
    lingrid : int, optional
        If given (together with `X0`), draw this many arrows along each
        streamline, at time indices spaced linearly over the simulation
        time vector.
    lintime : optional
        Not used by this implementation; accepted for compatibility.
    logtime : tuple (int, float), optional
        If a tuple (N, lambda) is given (together with `X0`), draw N
        arrows along each streamline at times spaced by 1/lambda.
    timepts : array-like, optional
        Draw arrows along each streamline at the given list of times
        [t1, t2, ...].
    parms : tuple, optional
        Deprecated alias for `params`.
    params: tuple, optional
        List of parameters to pass to vector field: `odefun(x, t, *params)`
    tfirst : bool, optional
        If True, call `odefun` with signature `odefun(t, x, ...)`.
    verbose : bool, optional
        If True (default), print a message when `lingrid` or `logtime`
        arrows are used.

    Raises
    ------
    ControlArgument
        If both `parms` and `params` are given.

    See also
    --------
    box_grid : construct box-shaped grid of initial conditions

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            f"keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument(f"duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if (verbose):
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if (verbose):
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # Arrows along the streamlines are only drawn when one of the arrow
    # placement options (lingrid, logtime, timepts) was given
    streamArrows = autoFlag or logtimeFlag or timeptsFlag

    if not streamArrows and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i,j], x2[i,j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i,j], x2[i,j]], 0, *params))

        # Plot the quiver plot
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            # dx holds the two vector components at indices 0 and 1
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif (scale != 0):
            #! TODO: optimize parameters for arrows
            #! TODO: figure out arguments to make arrows show up correctly
            xy = plt.quiver(x1, x2, dx[:,:,0]*np.abs(scale),
                            dx[:,:,1]*np.abs(scale), angles='xy')
            # set(xy, 'LineWidth', PP_arrow_linewidth, 'Color', 'b')

        #! TODO: Tweak the shape of the plot
        # a=gca; set(a,'DataAspectRatio',[1,1,1])
        # set(a,'XLim',X(1:2)); set(a,'YLim',Y(1:2))
        plt.xlabel('x1'); plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    (nr, nc) = np.shape(X0)

    # Arrays to keep arrow information; zero-initialized so that arrows
    # skipped below never pick up uninitialized memory
    x1 = np.zeros((nr, Narrows))
    x2 = np.zeros((nr, Narrows))
    dx = np.zeros((nr, Narrows, 2))

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed (any scalar, including numpy scalars,
    # is a final time)
    TSPAN = T
    if np.ndim(T) == 0:
        TSPAN = np.linspace(0, T, 100)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin = alim[0]; xmax = alim[1]
        ymin = alim[2]; ymax = alim[3]
    else:
        # Use the maximum extent of all trajectories
        xmin = np.min(X0[:,0]); xmax = np.max(X0[:,0])
        ymin = np.min(X0[:,1]); ymax = np.max(X0[:,1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = np.asarray(TSPAN)

        plt.plot(state[:,0], state[:,1])
        #! TODO: add back in colors for stream lines
        # PP_stream_color(np.mod(i-1, len(PP_stream_color))+1))
        # set(h[i], 'LineWidth', PP_stream_linewidth)

        # Plot arrows if quiver parameters were 'auto'
        if streamArrows:
            # Compute the locations of the arrows
            for j in range(Narrows):

                # Figure out starting index; headless arrows start at 0
                k = -1 if scale is None else 0

                # Figure out what time index to use for the next point
                if autoFlag:
                    # Use a linear scaling based on ODE time vector (cast
                    # to int so that it can be used as an array index)
                    tind = int(np.floor((len(time)/Narrows) * (j-k))) + k
                elif logtimeFlag:
                    # Use an exponential time vector
                    # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                    tarr = _find(time < (j-k) / timefactor)
                    tind = tarr[-1] if len(tarr) else 0
                elif timeptsFlag:
                    # Use specified time points
                    # MATLAB: tind = find(time < Y[j], 1, 'last')
                    tarr = _find(time < timepts[j])
                    tind = tarr[-1] if len(tarr) else 0

                # For tailless arrows, skip the first point
                if tind == 0 and scale is None:
                    continue

                # Figure out the arrow at this point on the curve
                x1[i,j] = state[tind, 0]
                x2[i,j] = state[tind, 1]

                # Skip arrows outside of initial condition box
                if (scale is not None or
                     (x1[i,j] <= xmax and x1[i,j] >= xmin and
                      x2[i,j] <= ymax and x2[i,j] >= ymin)):
                    if tfirst:
                        v = odefun(0, [x1[i,j], x2[i,j]], *params)
                    else:
                        v = odefun([x1[i,j], x2[i,j]], 0, *params)
                    dx[i, j, 0] = v[0]; dx[i, j, 1] = v[1]
                else:
                    dx[i, j, 0] = 0; dx[i, j, 1] = 0

    # Plot arrows on the streamlines (only if arrow locations were computed)
    if streamArrows and Narrows > 0:
        if scale is None:
            # Use a tailless arrow
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif scale != 0:
            #! TODO: figure out arguments to make arrows show up correctly
            xy = plt.quiver(x1, x2, dx[:,:,0]*abs(scale),
                            dx[:,:,1]*abs(scale), angles='xy')
            # set(xy, 'LineWidth', PP_arrow_linewidth)
            # set(xy, 'AutoScale', 'off')
            # set(xy, 'AutoScaleFactor', 0)

        # scale may be None, which cannot be compared with 0
        if scale is not None and scale < 0:
            bp = plt.plot(x1, x2, 'b.')         # add dots at base
            # set(bp, 'MarkerSize', PP_arrow_markersize)
1395

1396

1397
# Utility function for generating initial conditions around a box
1398
def box_grid(xlimp, ylimp):
    """box_grid   generate list of points on edge of box

    .. deprecated:: 0.10.0
        Use :func:`phaseplot.boxgrid` instead.

    list = box_grid([xmin xmax xnum], [ymin ymax ynum]) samples ``xnum``
    evenly spaced x values in [xmin, xmax] and ``ynum`` evenly spaced y
    values in [ymin, ymax].  It returns the points that :func:`boxgrid`
    generates from them, i.e. a uniform grid along the edge of the box
    with corners [xmin ymin] and [xmax ymax].

    """
    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    # Sample each side of the box, then delegate to the modern helper
    xsamples = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    ysamples = np.linspace(ylimp[0], ylimp[1], ylimp[2])
    return boxgrid(xsamples, ysamples)
1418

1419

1420
# TODO: rename to something more useful (or remove??)
1421
def _find(condition):
    """Return the flat indices at which `condition` is true.

    Private replacement for the deprecated ``matplotlib.mlab.find``:
    `condition` is flattened and the indices of its nonzero entries are
    returned as a 1D integer array.
    """
    return np.flatnonzero(condition)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc