• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13099029383

02 Feb 2025 01:00PM UTC coverage: 94.632% (-0.03%) from 94.659%
13099029383

Pull #1112

github

web-flow
Merge b63579848 into d2f9a9c54
Pull Request #1112: Use matplotlibs streamplot function for phase_plane_plot

9696 of 10246 relevant lines covered (94.63%)

8.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.95
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Author: Richard M. Murray
4
# Date:   23 Mar 2024 (legacy version information below)
5
#
6
# TODO
7
# * Allow multiple timepoints (and change timespec name to T?)
8
# * Update linestyles (color -> linestyle?)
9
# * Check for keyword compatibility with other plot routines
10
# * Set up configuration parameters (nyquist --> phaseplot)
11

12
"""Module for generating 2D phase plane plots.
13

14
The :mod:`control.phaseplot` module contains functions for generating 2D
15
phase plots. The base function for creating phase plane portraits is
16
:func:`~control.phase_plane_plot`, which generates a phase plane portrait
17
for a 2 state I/O system (with no inputs).  In addition, several other
18
functions are available to create customized phase plane plots:
19

20
* boxgrid: Generate a list of points along the edge of a box
21
* circlegrid: Generate list of points around a circle
22
* equilpoints: Plot equilibrium points in the phase plane
23
* meshgrid: Generate a list of points forming a mesh
24
* separatrices: Plot separatrices in the phase plane
25
* streamlines: Plot stream lines in the phase plane
26
* vectorfield: Plot a vector field in the phase plane
27
* streamplot: Plot streamlines using matplotlib's streamplot function
28

29
"""
30

31
import math
9✔
32
import warnings
9✔
33

34
import matplotlib as mpl
9✔
35
import matplotlib.pyplot as plt
9✔
36
import numpy as np
9✔
37
from scipy.integrate import odeint
9✔
38

39
from . import config
9✔
40
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
41
    _process_ax_keyword, _update_plot_title
42
from .exception import ControlNotImplemented
9✔
43
from .nlsys import NonlinearIOSystem, find_operating_point, \
9✔
44
    input_output_response
45

46
# Names exported by `from control.phaseplot import *`
__all__ = ['phase_plane_plot', 'phase_plot', 'box_grid']

# Default values for module parameter variables
_phaseplot_defaults = {
    # number of arrows around curve
    'phaseplot.arrows': 2,
    # pixel size for arrows
    'phaseplot.arrow_size': 8,
    # initial radius for separatrices
    'phaseplot.separatrices_radius': 0.1,
}
54

55
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
        plot_separatrices=True, ax=None, suppress_warnings=False, title=None,
        plot_streamplot=False, **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  If not given,
        [-1, 1, -1, 1] is used.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use `plot_<fcn>={'color': c}`
        to set the color in one element of the phase plot).
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : :class:`ControlPlot` object
        Object containing the data that were plotted:

          * cplt.lines: array of list of :class:`matplotlib.artist.Artist`
            objects:

              - lines[0] = list of Line2D objects (streamlines, separatrices).
              - lines[1] = Quiver object (vector field arrows).
              - lines[2] = list of Line2D objects (equilibrium points).
              - lines[3] = StreamplotSet object (lines with arrows).

            Elements that are not plotted are left as an empty list
            (lines[0]) or `None` (lines[1] through lines[3]).

          * cplt.axes: 2D array of :class:`matplotlib.axes.Axes` for the plot.

          * cplt.figure: :class:`matplotlib.figure.Figure` containing the plot.

        See :class:`ControlPlot` for more detailed information.


    Other parameters
    ----------------
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the ``timedata`` argument.
    plot_streamlines : bool or dict, optional
        If `True` (default) then plot streamlines based on the pointdata
        and gridtype.  If set to a dict, pass on the key-value pairs in
        the dict as keywords to :func:`~control.phaseplot.streamlines`.
    plot_vectorfield : bool or dict, optional
        If `True` then plot the vector field based on the pointdata and
        gridspec (default = `False`).  If set to a dict, pass on the
        key-value pairs in the dict as keywords to
        :func:`~control.phaseplot.vectorfield`.
    plot_streamplot : bool or dict, optional
        If `True` then use :func:`matplotlib.axes.Axes.streamplot` function
        to plot the streamlines (default = `False`).  If set to a dict, pass
        on the key-value pairs in the dict as keywords to
        :func:`~control.phaseplot.streamplot`.
    plot_equilpoints : bool or dict, optional
        If `True` (default) then plot equilibrium points based in the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to :func:`~control.phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If `True` (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to :func:`~control.phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        If set to `True`, suppress warning messages in generating
        trajectories (streamlines and separatrices).
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    """
    # Process arguments
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments.  Precedence (lowest to
    # highest): global keywords, explicit `other_kwargs`, then the
    # element-specific dict given via `plot_<fcn>={...}` (if any).
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Create array for storing outputs: one slot per entry of cplt.lines
    # (streamlines/separatrices, quiver, equilibrium points, streamplot)
    out = np.array([[], None, None, None], dtype=object)

    # Plot out the main elements
    if plot_streamlines:
        # suppress_warnings goes through _create_kwargs so that a value in
        # the plot_streamlines dict overrides it instead of colliding with it
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax, suppress_warnings=suppress_warnings)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_separatrices:
        # Separatrices also generate trajectories, so honor suppress_warnings
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax,
            suppress_warnings=suppress_warnings)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_streamplot:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
        # Keep the StreamplotSet so it is returned as cplt.lines[3]
        out[3] = streamplot(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by streamplot
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError("unrecognized keywords: ", str(initial_kwargs))

    # Only decorate axes that this function created itself
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
255

256

257
def vectorfield(
        sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system, evaluated with zero input.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate points.  Points are generated using the 'meshgrid' grid
        type.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver
        Quiver object containing the vector field arrows.

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        Currently not used by this function.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Generate phase plane (quiver) data; each row holds x, y, dx/dt, dy/dt
    sys._update_params(params)
    u_zero = np.zeros(sys.ninputs)      # vector field is evaluated at u = 0
    vfdata = np.zeros((points.shape[0], 4))
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, u_zero)

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color)

    return out
348

349

350
def streamplot(
        sys, pointdata, gridspec=None, ax=None, vary_color=False,
        vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot streamlines in the phase plane.

    This function plots the streamlines of a two-dimensional state space
    system using :func:`matplotlib.axes.Axes.streamplot`, with the vector
    field evaluated at zero input on a regular grid.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points from which to make the streamplot. In the latter case,
        the points must lie on a regular grid listed row by row, with the
        x coordinate increasing within each row (as generated by
        `meshgrid`).
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate points.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the streamlines in the given color.
    vary_color : bool, optional
        If set to True, vary the color of the streamlines based on the
        magnitude of the vector field.
    vary_linewidth : bool, optional
        If set to True, vary the linewidth of the streamlines based on the
        magnitude of the vector field (between 0.25 and 2 times the default
        line width).
    cmap : str or Colormap, optional
        Colormap to use for varying the color of the streamlines.
    norm : `matplotlib.colors.Normalize`, optional
        An instance of Normalize to use for scaling the colormap and
        linewidths.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : StreamplotSet
        Object returned by :func:`matplotlib.axes.Axes.streamplot`.

    Raises
    ------
    ValueError
        If `pointdata` is an array whose points cannot be arranged into a
        regular grid.

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.default['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Currently not used by this function.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the streamplot field
    points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')

    # If the grid size is unknown, attempt to recover it from the points:
    # every backward jump in the x values starts a new row (new y value)
    if gridspec is None:
        nrows = int(np.sum(np.diff(points[:, 0]) < 0)) + 1
        ncols = points.shape[0] // nrows
        if nrows * ncols != points.shape[0]:
            raise ValueError("Could not recover grid from points.")
        # gridspec is ordered [x size, y size], i.e. [ncols, nrows]
        gridspec = [ncols, nrows]

    # streamplot expects arrays with one row per y value: (y size, x size)
    grid_arr_shape = gridspec[::-1]
    xs = points[:, 0].reshape(grid_arr_shape)
    ys = points[:, 1].reshape(grid_arr_shape)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Generate streamplot data: vector field evaluated at u = 0
    sys._update_params(params)
    u_zero = np.zeros(sys.ninputs)
    us_flat, vs_flat = np.transpose(
        [sys._rhs(0, x, u_zero) for x in points])
    us = us_flat.reshape(grid_arr_shape)
    vs = vs_flat.reshape(grid_arr_shape)

    # Magnitude of the field, used for optional color/linewidth scaling
    magnitudes = np.linalg.norm([us, vs], axis=0)
    norm = mpl.colors.Normalize() if norm is None else norm
    normalized = norm(magnitudes)
    cmap = plt.get_cmap(cmap)

    with plt.rc_context(rcParams):
        default_lw = plt.rcParams['lines.linewidth']
        min_lw, max_lw = 0.25 * default_lw, 2 * default_lw
        linewidths = (normalized * (max_lw - min_lw) + min_lw
                      if vary_linewidth else None)
        color = magnitudes if vary_color else color

        out = ax.streamplot(
            xs, ys, us, vs, color=color, linewidth=linewidths, cmap=cmap,
            norm=norm)

    return out
459

460
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):
    """Plot stream lines in the phase plane.

    Simulate a two-dimensional state space system from a collection of
    initial conditions and draw each resulting trajectory as a line
    decorated with arrows.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        Either the plot boundaries as [xmin, xmax, ymin, ymax] or an array
        of shape (N, 2) listing the points used to generate streamlines.
    timedata : int or list of int
        Simulation time for each streamline.  A list assigns a separate
        time to each initial condition in `pointdata`.
    gridtype : str, optional
        How initial conditions are generated: 'meshgrid' (default) fills
        the boundaries with a mesh, 'boxgrid' places points along the edges
        of the boundary, and 'circlegrid' places a circle of points around
        each point in `pointdata`.
    gridspec : list, optional
        For 'meshgrid' or 'boxgrid', the number of grid points along the
        x and y axes.  For 'circlegrid', a 2-tuple giving the radius and
        the number of points around each point in `pointdata`.
    dir : str, optional
        Direction of integration: 'forward' flows forward in time from the
        reference points, 'reverse' flows backward in time, and 'both'
        does both, each for the duration given by ``timedata``.  If not
        given, 'both' is used when `gridtype` is 'meshgrid' and 'forward'
        otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Color used for every streamline.
    ax : matplotlib.axes.Axes
        Axes to draw on; the current axes are used if omitted.

    Returns
    -------
    out : list of Line2D objects
        One line per trajectory that was drawn.

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        When `True`, warnings raised while generating trajectories are
        suppressed.

    """
    # Pull out plot configuration and system parameters
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
    params = kwargs.pop('params', None)

    # Turn a plain callable into an I/O system when necessary
    sys = _create_system(sys, params)

    # Arrow placement and style along each line
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Initial conditions for the streamlines, plus the integration direction
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Fall back to the current axes
    ax = plt.gca() if ax is None else ax

    # Axis limits, which are also handed to the trajectory generator
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Line color
    color = _get_color(kwargs, ax=ax)

    # Anything still left in kwargs was not recognized
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Backward integration uses a copy of the system with negated dynamics
    revsys = None
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)

    # Simulate and draw one trajectory per initial condition
    lines = []
    for idx, x0 in enumerate(points):
        traj = _create_trajectory(
            sys, revsys, _make_timepts(timedata, idx), x0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Skip trajectories with fewer than two samples
        if traj.shape[1] <= 1:
            continue

        with plt.rc_context(rcParams):
            lines.extend(ax.plot(traj[0], traj[1], color=color))

            # Decorate the newest line with arrows at the requested spots
            _add_arrows_to_line2D(
                ax, lines[-1], arrow_pos, arrowstyle=arrow_style, dir=1)

    return lines
579

580

581
def equilpoints(
        sys, pointdata, gridspec=None, color='k', ax=None, _check_kwargs=True,
        **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points to use when searching for equilibrium points.
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate the points used in the search (default = [5, 5]).  Points
        are generated using the 'meshgrid' grid type.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str, optional
        Plot the equilibrium points in the given color (default = 'k').
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects
        One circle marker per equilibrium point found.

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Generate the mesh of points handed to _find_equilpts (presumably used
    # as starting guesses for the equilibrium search -- confirm in helper)
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError("unrecognized keywords: ", str(kwargs))

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points
    out = []
    for xeq in equilpts:
        with plt.rc_context(rcParams):
            out += ax.plot(xeq[0], xeq[1], marker='o', color=color)
    return out
663

664

665
def separatrices(
9✔
666
        sys, pointdata, timedata=None, gridspec=None, ax=None,
667
        _check_kwargs=True, suppress_warnings=False, **kwargs):
668
    """Plot separatrices in the phase plane.
669

670
    This function plots separatrices for a two-dimensional state space
671
    system.
672

673
    Parameters
674
    ----------
675
    sys : NonlinearIOSystem or callable(t, x, ...)
676
        I/O system or function used to generate phase plane data. If a
677
        function is given, the remaining arguments are drawn from the
678
        `params` keyword.
679
    pointdata : list or 2D array
680
        List of the form [xmin, xmax, ymin, ymax] describing the
681
        boundaries of the phase plot or an array of shape (N, 2)
682
        giving points of at which to plot the vector field.
683
    timedata : int or list of int
684
        Time to simulate each streamline.  If a list is given, a different
685
        time can be used for each initial condition in `pointdata`.
686
    gridtype : str, optional
687
        The type of grid to use for generating initial conditions:
688
        'meshgrid' (default) generates a mesh of initial conditions within
689
        the specified boundaries, 'boxgrid' generates initial conditions
690
        along the edges of the boundary, 'circlegrid' generates a circle of
691
        initial conditions around each point in point data.
692
    gridspec : list, optional
693
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
694
        size of the grid in the x and y axes on which to generate points.
695
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
696
        specifying the radius and number of points around each point in the
697
        `pointdata` array.
698
    params : dict or list, optional
699
        Parameters to pass to system. For an I/O system, `params` should be
700
        a dict of parameters and values. For a callable, `params` should be
701
        dict with key 'args' and value given by a tuple (passed to callable).
702
    color : matplotlib color spec, optional
703
        Plot the separatrics in the given color.  If a single color
704
        specification is given, this is used for both stable and unstable
705
        separatrices.  If a tuple is given, the first element is used as
706
        the color specification for stable separatrices and the second
707
        elmeent for unstable separatrices.
708
    ax : matplotlib.axes.Axes
709
        Use the given axes for the plot, otherwise use the current axes.
710

711
    Returns
712
    -------
713
    out : list of Line2D objects
714

715
    Other parameters
716
    ----------------
717
    rcParams : dict
718
        Override the default parameters used for generating plots.
719
        Default is set by config.default['ctrlplot.rcParams'].
720
    suppress_warnings : bool, optional
721
        If set to `True`, suppress warning messages in generating trajectories.
722

723
    """
724
    # Process keywords
725
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
726

727
    # Get system parameters
728
    params = kwargs.pop('params', None)
9✔
729

730
    # Create system from callable, if needed
731
    sys = _create_system(sys, params)
9✔
732

733
    # Parse the arrows keyword
734
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
735

736
    # Determine the initial states to use in searching for equilibrium points
737
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
738
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
739

740
    # Find the equilibrium points
741
    equilpts = _find_equilpts(sys, points, params=params)
9✔
742
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
743

744
    # Create axis if needed
745
    if ax is None:
9✔
746
        ax = plt.gca()
9✔
747

748
    # Set the axis limits
749
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
750

751
    # Figure out the color to use for stable, unstable subspaces
752
    color = _get_color(kwargs)
9✔
753
    match color:
9✔
754
        case None:
9✔
755
            stable_color = 'r'
9✔
756
            unstable_color = 'b'
9✔
757
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
758
            pass
9✔
759
        case single_color:
9✔
760
            stable_color = unstable_color = color
9✔
761

762
    # Make sure all keyword arguments were processed
763
    if _check_kwargs and kwargs:
9✔
764
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
765

766
    # Create a "reverse time" system to use for simulation
767
    revsys = NonlinearIOSystem(
9✔
768
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
769
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
770
        outputs=sys.noutputs, params=sys.params)
771

772
    # Plot separatrices by flowing backwards in time along eigenspaces
773
    out = []
9✔
774
    for i, xeq in enumerate(equilpts):
9✔
775
        # Plot the equilibrium points
776
        with plt.rc_context(rcParams):
9✔
777
            out += ax.plot(xeq[0], xeq[1], marker='o', color='k')
9✔
778

779
        # Figure out the linearization and eigenvectors
780
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
781

782
        # See if we have real eigenvalues (=> evecs are meaningful)
783
        if evals[0].imag > 0:
9✔
784
            continue
9✔
785

786
        # Create default list of time points
787
        if timedata is not None:
9✔
788
            timepts = _make_timepts(timedata, i)
9✔
789

790
        # Generate the traces
791
        for j, dir in enumerate(evecs.T):
9✔
792
            # Figure out time vector if not yet computed
793
            if timedata is None:
9✔
794
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
795
                timepts = np.linspace(0, timescale)
9✔
796

797
            # Run the trajectory starting in eigenvector directions
798
            for eps in [-radius, radius]:
9✔
799
                x0 = xeq + dir * eps
9✔
800
                if evals[j].real < 0:
9✔
801
                    traj = _create_trajectory(
9✔
802
                        sys, revsys, timepts, x0, params, 'reverse',
803
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
804
                        suppress_warnings=suppress_warnings)
805
                    color = stable_color
9✔
806
                    linestyle = '--'
9✔
807
                elif evals[j].real > 0:
9✔
808
                    traj = _create_trajectory(
9✔
809
                        sys, revsys, timepts, x0, params, 'forward',
810
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
811
                        suppress_warnings=suppress_warnings)
812
                    color = unstable_color
9✔
813
                    linestyle = '-'
9✔
814

815
                # Plot the trajectory (if there is one)
816
                if traj.shape[1] > 1:
9✔
817
                    with plt.rc_context(rcParams):
9✔
818
                        out += ax.plot(
9✔
819
                            traj[0], traj[1], color=color, linestyle=linestyle)
820

821
                    # Add arrows to the lines at specified intervals
822
                    with plt.rc_context(rcParams):
9✔
823
                        _add_arrows_to_line2D(
9✔
824
                            ax, out[-1], arrow_pos, arrowstyle=arrow_style,
825
                            dir=1)
826
    return out
9✔
827

828

829
#
830
# User accessible utility functions
831
#
832

833
# Utility function to generate boxgrid (in the form needed here)
834
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    points = boxgrid(xvals, yvals) generates a list of points tracing the
    perimeter of the box whose sides are sampled at `xvals` and `yvals`.
    The walk starts at (xvals[0], yvals[0]) and visits the lower, right,
    upper and left edges in turn, with each corner appearing exactly once
    (counter-clockwise when both sequences are increasing).

    Parameters
    ----------
    xvals, yvals : 1D array-like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge.

    """
    edges = (
        [(x, yvals[0]) for x in xvals[:-1]],        # lower, left to right
        [(xvals[-1], y) for y in yvals[:-1]],       # right, bottom to top
        [(x, yvals[-1]) for x in xvals[:0:-1]],     # upper, right to left
        [(xvals[0], y) for y in yvals[:0:-1]],      # left, top to bottom
    )
    perimeter = []
    for edge in edges:
        perimeter.extend(edge)
    return np.array(perimeter)
859

860

861
# Utility function to generate meshgrid (in the form needed here)
862
# TODO: add examples of using grid functions directly
863
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    points = meshgrid(xvals, yvals) generates a list of points that
    corresponds to a grid given by the cross product of the x and y values.

    Parameters
    ----------
    xvals, yvals : 1D array-like
        Coordinates of the mesh lines along the x and y axes.

    Returns
    -------
    grid: 2D array
        Float array of points with shape (n * m, 2) defining the mesh, where
        n and m are the lengths of `xvals` and `yvals`.  The x coordinate
        varies fastest.

    """
    xx, yy = np.meshgrid(xvals, yvals)
    npoints = xx.shape[0] * xx.shape[1]
    grid = np.zeros((npoints, 2))
    for column, coords in enumerate((xx, yy)):
        grid[:, column] = coords.reshape(-1)
    return grid
887

888

889
# Utility function to generate circular grid
890
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    points = circlegrid(centers, radius, num) generates a list of points
    that form a circle around a list of centers.  The points on each circle
    are equally spaced in angle, starting at angle 0 (to the right of the
    center) and proceeding counter-clockwise.

    Parameters
    ----------
    centers : 2D array-like
        Array of points with shape (p, 2) defining centers of the circles.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid: 2D array
        Array of points with shape (p * num, 2) defining the circles.

    """
    centers = np.atleast_2d(np.array(centers))

    # The offsets from the center are the same for every circle
    angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
    offsets = np.array([
        [radius * math.cos(angle), radius * math.sin(angle)]
        for angle in angles])

    grid = np.zeros((centers.shape[0] * num, 2))
    for k, center in enumerate(centers):
        start = k * num
        grid[start:start + num, :] = center + offsets
    return grid
918

919
#
920
# Internal utility functions
921
#
922

923
# Create a system from a callable
924
def _create_system(sys, params):
    """Return a planar NonlinearIOSystem for `sys`.

    If `sys` is already a NonlinearIOSystem, it is checked to have two
    states and returned unchanged.  Otherwise `sys` is treated as a
    callable ``sys(t, x, *args)`` returning the state derivative, and is
    wrapped in a system with two states, no inputs and no outputs.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        System or right-hand-side function.
    params : dict, optional
        For a callable, a dict whose 'args' entry is a tuple of extra
        positional arguments passed to `sys` (an empty tuple is allowed).

    Returns
    -------
    sys : NonlinearIOSystem

    Raises
    ------
    ValueError
        If an I/O system does not have exactly two states, or if `params`
        is given for a callable but is not a dict with an 'args' entry.

    """
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates != 2:
            raise ValueError("system must be planar")
        return sys

    # Make sure that if params is present, it is a dict with an 'args' key.
    # An empty args tuple is valid; a non-dict would otherwise fail below
    # with an AttributeError instead of this message.
    if params and (
            not isinstance(params, dict) or params.get('args') is None):
        raise ValueError("params must be dict with key 'args'")

    # The `params` seen here are the ones supplied at simulation time
    def _update(t, x, u, params):
        return sys(t, x, *params.get('args', ()))

    def _output(t, x, u, params):
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
938

939
# Set axis limits for the plot
940
def _set_axis_limits(ax, pointdata):
9✔
941
    # Get the current axis limits
942
    if ax.lines:
9✔
943
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
9✔
944
    else:
945
        # Nothing on the plot => always use new limits
946
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]
9✔
947

948
    # Short utility function for updating axis limits
949
    def _update_limits(cur, new):
9✔
950
        return [min(cur[0], np.min(new)), max(cur[1], np.max(new))]
9✔
951

952
    # If we were passed a box, use that to update the limits
953
    if isinstance(pointdata, list) and len(pointdata) == 4:
9✔
954
        xlim = _update_limits(xlim, [pointdata[0], pointdata[1]])
9✔
955
        ylim = _update_limits(ylim, [pointdata[2], pointdata[3]])
9✔
956

957
    elif isinstance(pointdata, np.ndarray):
9✔
958
        pointdata = np.atleast_2d(pointdata)
9✔
959
        xlim = _update_limits(
9✔
960
            xlim, [np.min(pointdata[:, 0]), np.max(pointdata[:, 0])])
961
        ylim = _update_limits(
9✔
962
            ylim, [np.min(pointdata[:, 1]), np.max(pointdata[:, 1])])
963

964
    # Keep track of the largest dimension on the plot
965
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])
9✔
966

967
    # Set the new limits
968
    ax.autoscale(enable=True, axis='x', tight=True)
9✔
969
    ax.autoscale(enable=True, axis='y', tight=True)
9✔
970
    ax.set_xlim(xlim)
9✔
971
    ax.set_ylim(ylim)
9✔
972

973
    return xlim, ylim, maxlim
9✔
974

975

976
# Find equilibrium points
977
def _find_equilpts(sys, points, params=None):
    """Search for distinct equilibrium points of `sys`.

    Calls :func:`find_operating_point` (with zero input) starting from each
    point in `points`.  Equilibria that are numerically close
    (:func:`numpy.allclose`) to one already found are discarded.

    Returns
    -------
    equilpts : list of list
        The distinct equilibrium states found, each converted to a list.

    """
    found = []
    for guess in points:
        xeq, _ = find_operating_point(sys, guess, 0, params=params)
        if xeq is None:
            continue            # no equilibrium found from this guess

        # Only keep points we have not seen before
        if not any(np.allclose(np.array(prev), xeq) for prev in found):
            found.append(xeq.tolist())

    return found
998

999

1000
def _make_points(pointdata, gridspec, gridtype):
9✔
1001
    # Check to see what type of data we got
1002
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
1003
        pointdata = np.atleast_2d(pointdata)
9✔
1004
        if pointdata.shape[1] == 2:
9✔
1005
            # Given a list of points => no action required
1006
            return pointdata, None
9✔
1007

1008
    # Utility function to parse (and check) input arguments
1009
    def _parse_args(defsize):
9✔
1010
        if gridspec is None:
9✔
1011
            return defsize
9✔
1012

1013
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
1014
             len(gridspec) != len(defsize):
1015
            raise ValueError("invalid grid specification")
9✔
1016

1017
        return gridspec
9✔
1018

1019
    # Generate points based on grid type
1020
    match gridtype:
9✔
1021
        case 'boxgrid' | None:
9✔
1022
            gridspec = _parse_args([6, 4])
9✔
1023
            points = boxgrid(
9✔
1024
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1025
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1026

1027
        case 'meshgrid':
9✔
1028
            gridspec = _parse_args([9, 6])
9✔
1029
            points = meshgrid(
9✔
1030
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1031
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1032

1033
        case 'circlegrid':
9✔
1034
            gridspec = _parse_args((0.5, 10))
9✔
1035
            if isinstance(pointdata, np.ndarray):
9✔
1036
                # Create circles around each point
1037
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
1038
            else:
1039
                # Create circle around center of the plot
1040
                points = circlegrid(
9✔
1041
                    np.array(
1042
                        [(pointdata[0] + pointdata[1]) / 2,
1043
                         (pointdata[0] + pointdata[1]) / 2]),
1044
                    gridspec[0], gridspec[1])
1045

1046
        case _:
9✔
1047
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
1048

1049
    return points, gridspec
9✔
1050

1051

1052
def _parse_arrow_keywords(kwargs):
    """Extract arrow placement and style from keyword arguments.

    The 'arrows', 'arrow_size' and 'arrow_style' keywords are looked up via
    ``config._get_param`` with the 'phaseplot' defaults and removed from
    `kwargs`.  Callers can then treat any remaining keywords as
    unrecognized.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments; modified in place.

    Returns
    -------
    arrow_pos : list or 1D array
        Relative arrow positions (empty if arrows are disabled).  An integer
        N gives N positions spaced 1/N apart, starting at 0.5/N.
    arrow_style : matplotlib.patches.ArrowStyle
        Arrow style; built from `arrow_size` if no style was given.

    Raises
    ------
    ValueError
        If the 'arrows' value is not a bool/int, list or array.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    # Pop this one too, so it is not later reported as unrecognized
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # Parse the arrows keyword
    if not arrows:
        arrow_pos = []
    elif isinstance(arrows, int):
        N = arrows
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    elif isinstance(arrows, (list, np.ndarray)):
        arrow_pos = np.sort(np.atleast_1d(arrows))
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
1080

1081

1082
# TODO: move to ctrlplot?
1083
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory through `X0` and clip it to the plot window.

    Parameters
    ----------
    sys : NonlinearIOSystem
        System simulated for the forward-time segment.
    revsys : NonlinearIOSystem
        Time-reversed system simulated for the reverse-time segment.
    timepts : array-like
        Time points passed to :func:`input_output_response`.
    X0 : array-like
        Initial state.
    params : dict or None
        Parameters passed to the simulation.
    dir : {'forward', 'reverse', 'both'}
        Which time direction(s) to simulate.
    suppress_warnings : bool, optional
        If True, do not warn when a simulation reports failure.
    gridtype, gridspec : optional
        Not used here; accepted for call compatibility.
    xlim, ylim : sequence of float
        [min, max] limits of the plot window.  Points outside the window
        are dropped, except those adjacent to an in-range point.

    Returns
    -------
    traj : 2D array
        State trajectory (one row per state), ordered from earliest to
        latest time.

    Raises
    ------
    ValueError
        If `dir` is not one of the accepted values.

    """
    if dir not in ('forward', 'reverse', 'both'):
        raise ValueError(f"unknown trajectory direction '{dir}'")

    # Compute the forward trajectory
    if dir in ('forward', 'both'):
        fwdresp = input_output_response(
            sys, timepts, X0=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir in ('reverse', 'both'):
        revresp = input_output_response(
            revsys, timepts, X0=X0, params=params, ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {revresp.message}")

    # Create the trace to plot
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Reversed backward segment minus its first point (X0, which also
        # starts the forward segment), followed by the forward segment
        traj = np.hstack([revresp.states[:, :0:-1], fwdresp.states])

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.asarray(
        (traj[0] >= xlim[0]) & (traj[0] <= xlim[1]) &
        (traj[1] >= ylim[0]) & (traj[1] <= ylim[1]))
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
1116

1117

1118
def _make_timepts(timepts, i):
9✔
1119
    if timepts is None:
9✔
1120
        return np.linspace(0, 1)
9✔
1121
    elif isinstance(timepts, (int, float)):
9✔
1122
        return np.linspace(0, timepts)
9✔
1123
    elif timepts.ndim == 2:
×
1124
        return timepts[i]
×
1125
    return timepts
×
1126

1127

1128
#
1129
# Legacy phase plot function
1130
#
1131
# Author: Richard Murray
1132
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1133
# a version by Kristi Morgansen
1134
#
1135
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):
    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use :func:`phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by the :func:`~control.phase_plane_map` and
    :func:`~control.phase_plane_plot` functions.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - stream lines
          with N arrows on each
      phase_plot(func, X0=[...], logtime=(N, lambda), ...) - stream lines
          with arrows placed by time

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Computes the time derivative of the state (compatible with
        :func:`scipy.integrate.odeint`).  It is called as
        ``odefun(x, t, *params)``, or as ``odefun(t, x, *params)`` if
        `tfirst` is True.  The state x has dimension 2 and the returned
        derivative dx/dt must also have dimension 2.
    X, Y: 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.
    scale: float, optional
        Scale size of arrows; default = 1.  If 0, arrows are not drawn.
        If negative, abs(scale) is used and dots are added at the bases
        of the stream line arrows.  If None, arrows are drawn unscaled and
        the current axis limits bound where stream line arrows are placed.
    X0: ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T: array-like or number, optional
        Simulation time for the streamlines.  A number gives 100 equally
        spaced time points on [0, T]; anything else is used directly as
        the time points passed to odeint for every initial condition.
        Default value = 50.
    lingrid : integer, optional
        Draw this many arrows along each stream line, equally spaced in the
        simulation time vector.  The vector field is not plotted.
    lintime : optional
        Accepted for backward compatibility; not used.
    logtime : tuple (integer, float), optional
        Given (N, lambda), draw N arrows along each stream line, the j-th
        placed at the last time point before j / lambda ((j + 1) / lambda
        when `scale` is None).
    timepts : array-like, optional
        Draw arrows at the last time points before the given list of times
        [t1, t2, ...].
    parms : tuple, optional
        Deprecated alias for `params`.
    params: tuple, optional
        List of parameters to pass to vector field: `func(x, t, *params)`
    tfirst : bool, optional
        If True, call `func` with signature `func(t, x, ...)`.
    verbose : bool, optional
        If True (default), print a message about the arrow mode in use.

    See also
    --------
    box_grid : construct box-shaped grid of initial conditions

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            f"keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument(f"duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if (verbose):
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if (verbose):
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # Arrows along the stream lines are only generated for these options
    arrowFlag = autoFlag or logtimeFlag or timeptsFlag

    if not arrowFlag and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i,j], x2[i,j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i,j], x2[i,j]], 0, *params))

        # Plot the quiver plot
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            # The last axis of dx holds the two derivative components
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif (scale != 0):
            #! TODO: optimize parameters for arrows
            plt.quiver(x1, x2, dx[:,:,0]*np.abs(scale),
                       dx[:,:,1]*np.abs(scale), angles='xy')

        plt.xlabel('x1'); plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    (nr, nc) = np.shape(X0)

    # Arrays holding the stream line arrows; zero-filled so that entries
    # which are never computed below do not contain uninitialized memory
    x1 = np.zeros((nr, Narrows))
    x2 = np.zeros((nr, Narrows))
    dx = np.zeros((nr, Narrows, 2))

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # A number is the final time; anything else is the time vector itself
    TSPAN = T
    if isinstance(T, (int, float)):
        TSPAN = np.linspace(0, T, 100)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin = alim[0]; xmax = alim[1]
        ymin = alim[2]; ymax = alim[3]
    else:
        # Use the extent of the initial conditions
        xmin = np.min(X0[:,0]); xmax = np.max(X0[:,0])
        ymin = np.min(X0[:,1]); ymax = np.max(X0[:,1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = np.asarray(TSPAN)    # allow comparisons if T was a list

        plt.plot(state[:,0], state[:,1])
        #! TODO: add back in colors for stream lines

        # Compute arrows only if an arrow option was given
        if not arrowFlag:
            continue

        for j in range(Narrows):
            # Figure out starting index; headless arrows start at 0
            k = -1 if scale is None else 0

            # Figure out what time index to use for the next point
            if autoFlag:
                # Use a linear scaling based on ODE time vector; convert to
                # int so the result can be used as an array index
                tind = int(np.floor((len(time)/Narrows) * (j-k))) + k
            elif logtimeFlag:
                # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                tarr = _find(time < (j-k) / timefactor)
                tind = tarr[-1] if len(tarr) else 0
            elif timeptsFlag:
                # MATLAB: tind = find(time < Y[j], 1, 'last')
                tarr = _find(time < timepts[j])
                tind = tarr[-1] if len(tarr) else 0

            # For tailless arrows, skip the first point
            if tind == 0 and scale is None:
                continue

            # Figure out the arrow at this point on the curve
            x1[i,j] = state[tind, 0]
            x2[i,j] = state[tind, 1]

            # Skip arrows outside of initial condition box
            if (scale is not None or
                 (x1[i,j] <= xmax and x1[i,j] >= xmin and
                  x2[i,j] <= ymax and x2[i,j] >= ymin)):
                if tfirst:
                    v = odefun(0, [x1[i,j], x2[i,j]], *params)
                else:
                    v = odefun([x1[i,j], x2[i,j]], 0, *params)
                dx[i, j, 0] = v[0]; dx[i, j, 1] = v[1]
            else:
                dx[i, j, 0] = 0; dx[i, j, 1] = 0

    # Plot arrows on the streamlines (only if they were computed above)
    if arrowFlag and Narrows > 0:
        if scale is None:
            # Use a tailless arrow
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif scale != 0:
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:,:,0]*abs(scale), dx[:,:,1]*abs(scale),
                       angles='xy')

        # Add dots at the arrow bases for negative scale (None is not
        # comparable with 0, so check for it first)
        if scale is not None and scale < 0:
            plt.plot(x1, x2, 'b.')
1396

1397

1398
# Utility function for generating initial conditions around a box
1399
def box_grid(xlimp, ylimp):
    """box_grid   generate list of points on edge of box

    .. deprecated:: 0.10.0
        Use :func:`phaseplot.boxgrid` instead.

    list = box_grid([xmin xmax xnum], [ymin ymax ynum]) generates a
    list of points that correspond to a uniform grid along the edges of
    the box defined by the corners [xmin ymin] and [xmax ymax], with xnum
    and ynum points along the x and y sides.

    """
    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    xvals = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    yvals = np.linspace(ylimp[0], ylimp[1], ylimp[2])
    return boxgrid(xvals, yvals)
1419

1420

1421
# TODO: rename to something more useful (or remove??)
1422
def _find(condition):
9✔
1423
    """Returns indices where ravel(a) is true.
1424
    Private implementation of deprecated matplotlib.mlab.find
1425
    """
1426
    return np.nonzero(np.ravel(condition))[0]
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc