• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13608802463

01 Mar 2025 09:05PM UTC coverage: 94.745% (+0.001%) from 94.744%
13608802463

Pull #1133

github

web-flow
Merge bb35a88eb into f6799ab8e
Pull Request #1133: Fix Latex not being rendered in HTML output in VSCode

9863 of 10410 relevant lines covered (94.75%)

8.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.38
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Initial author: Richard M. Murray
4
# Creation date: 24 July 2011, converted from MATLAB version (2002);
5
# based on an original version by Kristi Morgansen
6

7
"""Generate 2D phase portraits.
8

9
This module contains functions for generating 2D phase plots. The base
10
function for creating phase plane portraits is `~control.phase_plane_plot`,
11
which generates a phase plane portrait for a 2 state I/O system (with no
12
inputs). Utility functions are available to customize the individual
13
elements of a phase plane portrait.
14

15
The docstring examples assume the following import commands::
16

17
  >>> import numpy as np
18
  >>> import control as ct
19
  >>> import control.phaseplot as pp
20

21
"""
22

23
import math
9✔
24
import warnings
9✔
25

26
import matplotlib as mpl
9✔
27
import matplotlib.pyplot as plt
9✔
28
import numpy as np
9✔
29
from scipy.integrate import odeint
9✔
30

31
from . import config
9✔
32
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
33
    _process_ax_keyword, _update_plot_title
34
from .exception import ControlArgument
9✔
35
from .nlsys import NonlinearIOSystem, find_operating_point, \
9✔
36
    input_output_response
37

38
__all__ = ['phase_plane_plot', 'phase_plot', 'box_grid']
9✔
39

40
# Default values for the 'phaseplot.*' module parameters.  The function
# docstrings in this module refer to several of these entries as
# `config.defaults['phaseplot.<name>']`.
41
_phaseplot_defaults = {
9✔
42
    'phaseplot.arrows': 2,                  # number of arrows along each curve
43
    'phaseplot.arrow_size': 8,              # size of the arrows, in pixels
44
    'phaseplot.arrow_style': None,          # arrow style (None = not set here)
45
    'phaseplot.separatrices_radius': 0.1    # initial radius for separatrices
46
}
47

48

49
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=None, plot_vectorfield=None, plot_streamplot=None,
        plot_equilpoints=True, plot_separatrices=True, ax=None,
        suppress_warnings=False, title=None, **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.
    If none of plot_streamlines, plot_vectorfield, or plot_streamplot are
    set, then plot_streamplot is used by default.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  If not
        given, [-1, 1, -1, 1] is used.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.  When only a streamplot is drawn and `gridspec`
        is not given, [25, 25] is used.
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color.  Use ``plot_<element>`` =
        {'color': c} to set the color in one element of the phase
        plot (equilpoints, separatrices, streamlines, etc).
    ax : `matplotlib.axes.Axes`, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : `ControlPlot` object
        Object containing the data that were plotted.  See `ControlPlot`
        for more detailed information.
    cplt.lines : array of list of `matplotlib.lines.Line2D`
        Array of list of `matplotlib.artist.Artist` objects:

            - lines[0] = list of Line2D objects (streamlines, separatrices).
            - lines[1] = Quiver object (vector field arrows).
            - lines[2] = list of Line2D objects (equilibrium points).
            - lines[3] = StreamplotSet object (lines with arrows).

    cplt.axes : 2D array of `matplotlib.axes.Axes`
        Axes for each subplot.
    cplt.figure : `matplotlib.figure.Figure`
        Figure containing the plot.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
    plot_streamlines : bool or dict, optional
        If True then plot streamlines based on the pointdata and gridtype.
        If set to a dict, pass on the key-value pairs in the dict as
        keywords to `streamlines`.
    plot_vectorfield : bool or dict, optional
        If True then plot the vector field based on the pointdata and
        gridtype.  If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.vectorfield`.
    plot_streamplot : bool or dict, optional
        If True then use `matplotlib.axes.Axes.streamplot` function
        to plot the streamlines.  If set to a dict, pass on the key-value
        pairs in the dict as keywords to `phaseplot.streamplot`.
    plot_equilpoints : bool or dict, optional
        If True (default) then plot equilibrium points based in the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If True (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to `phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    Raises
    ------
    ValueError
        If a streamplot is requested with a `gridtype` other than
        'meshgrid'.
    TypeError
        If any keyword argument is not consumed by one of the plotted
        elements.

    Notes
    -----
    The default method for producing streamlines is determined based on which
    keywords are specified, with `plot_streamplot` serving as the generic
    default.  If any of the `arrows`, `arrow_size`, `arrow_style`, or `dir`
    keywords are used and `plot_streamlines` is not set, a `FutureWarning`
    is issued and `plot_streamlines` is set to True.  If `gridtype` is
    something other than 'meshgrid' and `plot_streamlines` is not set, a
    warning is issued and `plot_streamlines` is set to True.  If none of
    `plot_streamlines`, `plot_vectorfield`, or `plot_streamplot` is set,
    then `plot_streamplot` will be set to True.

    """
    # Check for legacy usage of plot_streamlines
    streamline_keywords = ['arrows', 'arrow_size', 'arrow_style', 'dir']
    if plot_streamlines is None:
        if any(kw in kwargs for kw in streamline_keywords):
            warnings.warn(
                "detected streamline keywords; use plot_streamlines to set",
                FutureWarning)
            plot_streamlines = True
        if gridtype not in [None, 'meshgrid']:
            warnings.warn(
                "streamplots only support gridtype='meshgrid'; "
                "falling back to streamlines")
            plot_streamlines = True

    # Nothing requested explicitly: fall back to the streamplot
    if plot_streamlines is None and plot_vectorfield is None \
       and plot_streamplot is None:
        plot_streamplot = True

    # A streamplot on its own uses a denser default grid
    if plot_streamplot and not plot_streamlines and not plot_vectorfield:
        gridspec = gridspec or [25, 25]

    # Process arguments
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments; entries in the
    # element-specific dict (if one was given) take precedence
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Utility function to track the largest zorder seen so far.  The
    # comparison is against None (not truthiness) so that a zorder of 0
    # is treated as a real value.
    def _highest_zorder(current, new):
        if new is None:
            return current
        return new if current is None else max(current, new)

    # Create list for storing outputs
    out = np.array([[], None, None, None], dtype=object)

    # the maximum zorder of streamlines, vectorfield or streamplot
    flow_zorder = None

    # Plot out the main elements
    if plot_streamlines:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False,
            suppress_warnings=suppress_warnings, **kwargs_local)

        # No trajectory may have been drawn, in which case there is no
        # zorder to record (max() of an empty sequence would raise)
        if out[0]:
            flow_zorder = _highest_zorder(
                flow_zorder, max(elem.get_zorder() for elem in out[0]))

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        flow_zorder = _highest_zorder(flow_zorder, out[1].get_zorder())

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_streamplot:
        if gridtype not in [None, 'meshgrid']:
            raise ValueError(
                "gridtype must be 'meshgrid' when using streamplot")

        kwargs_local = _create_kwargs(
            kwargs, plot_streamplot, gridspec=gridspec, ax=ax)
        out[3] = streamplot(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        flow_zorder = _highest_zorder(
            flow_zorder,
            max(out[3].lines.get_zorder(), out[3].arrows.get_zorder()))

        # Get rid of keyword arguments handled by streamplot
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    # Separatrices are drawn one level above the flow elements
    sep_zorder = flow_zorder + 1 if flow_zorder is not None else None

    if plot_separatrices:
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax)
        kwargs_local['zorder'] = kwargs_local.get('zorder', sep_zorder)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False,  **kwargs_local)

        # Use the zorder of what was actually drawn (the user may have
        # overridden it through the plot_separatrices dict)
        sep_zorder = max(elem.get_zorder() for elem in out[0]) if out[0] \
            else None

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    # Equilibrium points are drawn one level above the separatrices
    equil_zorder = sep_zorder + 1 if sep_zorder is not None else None

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        kwargs_local['zorder'] = kwargs_local.get('zorder', equil_zorder)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError(f"unrecognized keywords: {initial_kwargs}")

    # Only decorate the axes when they were created here
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
318

319

320
def vectorfield(
        sys, pointdata, gridspec=None, zorder=None, ax=None,
        suppress_warnings=False, _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate
        points.  The points are always generated using the 'meshgrid'
        grid type.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver

    Raises
    ------
    TypeError
        If `_check_kwargs` is True and unprocessed keywords remain.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions;
        it is not used by this function.
    zorder : float, optional
        Set the zorder for the vectorfield.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.quiver`.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Generate phase plane (quiver) data: each row is [x1, x2, dx1, dx2],
    # with the dynamics evaluated at t = 0 and zero input
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, np.zeros(sys.ninputs))

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color, zorder=zorder)

    return out
414

415

416
def streamplot(
        sys, pointdata, gridspec=None, zorder=None, ax=None, vary_color=False,
        vary_linewidth=False, cmap=None, norm=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot streamlines in the phase plane.

    This function plots the streamlines for a two-dimensional state
    space system using the `matplotlib.axes.Axes.streamplot` function.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot.
    gridspec : list, optional
        Specifies the size of the grid in the x and y axes on which to
        generate points.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the streamlines in the given color.  Ignored if `vary_color`
        is True.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : StreamplotSet
        Container object with lines and arrows contained in the
        streamplot. See `matplotlib.axes.Axes.streamplot` for details.

    Raises
    ------
    TypeError
        If `_check_kwargs` is True and unprocessed keywords remain.

    Other Parameters
    ----------------
    cmap : str or Colormap, optional
        Colormap to use for varying the color of the streamlines.
    norm : `matplotlib.colors.Normalize`, optional
        Normalization map to use for scaling the colormap and linewidths.
        If not given, a default `matplotlib.colors.Normalize` is used.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions;
        it is not used by this function.
    vary_color : bool, optional
        If set to True, vary the color of the streamlines based on the
        magnitude of the vector field.
    vary_linewidth : bool, optional
        If set to True, vary the linewidth of the streamlines based on the
        magnitude of the vector field.
    zorder : float, optional
        Set the zorder for the streamlines.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.streamplot`.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the streamplot field.  The
    # flat list of points is reshaped into 2D arrays whose shape is the
    # grid specification in reverse order.
    points, gridspec = _make_points(pointdata, gridspec, 'meshgrid')
    grid_arr_shape = gridspec[::-1]
    xs = points[:, 0].reshape(grid_arr_shape)
    ys = points[:, 1].reshape(grid_arr_shape)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Generate phase plane (streamplot) data, with the dynamics evaluated
    # at t = 0 and zero input
    sys._update_params(params)
    us_flat, vs_flat = np.transpose(
        [sys._rhs(0, x, np.zeros(sys.ninputs)) for x in points])
    us, vs = us_flat.reshape(grid_arr_shape), vs_flat.reshape(grid_arr_shape)

    # Magnitude of the vector field, used for color and linewidth scaling
    magnitudes = np.linalg.norm([us, vs], axis=0)
    if norm is None:
        norm = mpl.colors.Normalize()
    normalized = norm(magnitudes)
    cmap = plt.get_cmap(cmap)

    with plt.rc_context(rcParams):
        # Linewidths range from 1/4 to 2 times the default line width
        default_lw = plt.rcParams['lines.linewidth']
        min_lw, max_lw = 0.25*default_lw, 2*default_lw
        linewidths = normalized * (max_lw - min_lw) + min_lw \
            if vary_linewidth else None
        color = magnitudes if vary_color else color

        out = ax.streamplot(
            xs, ys, us, vs, color=color, linewidth=linewidths, cmap=cmap,
            norm=norm, zorder=zorder)

    return out
526

527

528
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        zorder=None, ax=None, _check_kwargs=True, suppress_warnings=False,
        **kwargs):
    """Plot stream lines in the phase plane.

    This function plots stream lines for a two-dimensional state space
    system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to start the stream lines.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
        Default is 1.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
        If not given, 'both' is used when `gridtype` is 'meshgrid' and
        'forward' is used otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If `_check_kwargs` is True and unprocessed keywords remain.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.
    zorder : float, optional
        Set the zorder for the streamlines.  If not specified, it will be
        automatically chosen by `matplotlib.axes.Axes.plot`.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Parse the arrows keyword
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Determine the points on which to generate the streamlines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Create reverse time system, if needed (same dynamics with the sign
    # of the update function flipped)
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)
    else:
        revsys = None

    # Generate phase plane (streamline) data
    out = []
    for i, X0 in enumerate(points):
        # Create the trajectory for this point
        timepts = _make_timepts(timedata, i)
        traj = _create_trajectory(
            sys, revsys, timepts, X0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Plot the trajectory (if there is one)
        if traj.shape[1] > 1:
            with plt.rc_context(rcParams):
                out += ax.plot(traj[0], traj[1], color=color, zorder=zorder)

                # Add arrows to the lines at specified intervals
                _add_arrows_to_line2D(
                    ax, out[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return out
660

661

662
def equilpoints(
        sys, pointdata, gridspec=None, color='k', zorder=None, ax=None,
        _check_kwargs=True, **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot.  The data is used both to update
        the axis limits and to build the mesh of initial guesses for the
        equilibrium point search.
    gridspec : list, optional
        Size of the mesh (in the x and y axes) of initial guesses used
        when searching for equilibrium points.  Defaults to [5, 5].  The
        guesses are always generated on a mesh grid.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the equilibrium points in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If `_check_kwargs` is True and unrecognized keywords remain.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    zorder : float, optional
        Set the zorder for the equilibrium points.  If not specified, it will
        be automatically chosen by `matplotlib.axes.Axes.plot`.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (only the side effect on `ax` is needed here)
    _set_axis_limits(ax, pointdata)

    # Determine the initial guesses for the equilibrium point search
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        # Single message string (not a tuple of args) so str(exc) is readable
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points (one rc context for all of them)
    out = []
    with plt.rc_context(rcParams):
        for xeq in equilpts:
            out += ax.plot(
                xeq[0], xeq[1], marker='o', color=color, zorder=zorder)
    return out
9✔
748

749

750
def separatrices(
9✔
751
        sys, pointdata, timedata=None, gridspec=None, zorder=None, ax=None,
752
        _check_kwargs=True, suppress_warnings=False, **kwargs):
753
    """Plot separatrices in the phase plane.
754

755
    This function plots separatrices for a two-dimensional state space
756
    system.
757

758
    Parameters
759
    ----------
760
    sys : `NonlinearIOSystem` or callable(t, x, ...)
761
        I/O system or function used to generate phase plane data. If a
762
        function is given, the remaining arguments are drawn from the
763
        `params` keyword.
764
    pointdata : list or 2D array
765
        List of the form [xmin, xmax, ymin, ymax] describing the
766
        boundaries of the phase plot or an array of shape (N, 2)
767
        giving points of at which to plot the vector field.
768
    timedata : int or list of int
769
        Time to simulate each streamline.  If a list is given, a different
770
        time can be used for each initial condition in `pointdata`.
771
    gridtype : str, optional
772
        The type of grid to use for generating initial conditions:
773
        'meshgrid' (default) generates a mesh of initial conditions within
774
        the specified boundaries, 'boxgrid' generates initial conditions
775
        along the edges of the boundary, 'circlegrid' generates a circle of
776
        initial conditions around each point in point data.
777
    gridspec : list, optional
778
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
779
        size of the grid in the x and y axes on which to generate points.
780
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
781
        specifying the radius and number of points around each point in the
782
        `pointdata` array.
783
    params : dict or list, optional
784
        Parameters to pass to system. For an I/O system, `params` should be
785
        a dict of parameters and values. For a callable, `params` should be
786
        dict with key 'args' and value given by a tuple (passed to callable).
787
    color : matplotlib color spec, optional
788
        Plot the separatrices in the given color.  If a single color
789
        specification is given, this is used for both stable and unstable
790
        separatrices.  If a tuple is given, the first element is used as
791
        the color specification for stable separatrices and the second
792
        element for unstable separatrices.
793
    ax : `matplotlib.axes.Axes`, optional
794
        Use the given axes for the plot, otherwise use the current axes.
795

796
    Returns
797
    -------
798
    out : list of Line2D objects
799

800
    Other Parameters
801
    ----------------
802
    rcParams : dict
803
        Override the default parameters used for generating plots.
804
        Default is set by `config.defaults['ctrlplot.rcParams']`.
805
    suppress_warnings : bool, optional
806
        If set to True, suppress warning messages in generating trajectories.
807
    zorder : float, optional
808
        Set the zorder for the separatrices.  In not specified, it will be
809
        automatically chosen by `matplotlib.axes.Axes.plot`.
810

811
    Notes
812
    -----
813
    The value of `config.defaults['separatrices_radius']` is used to set the
814
    offset from the equilibrium point to the starting point of the separatix
815
    traces, in the direction of the eigenvectors evaluated at that
816
    equilibrium point.
817

818
    """
819
    # Process keywords
820
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
821

822
    # Get system parameters
823
    params = kwargs.pop('params', None)
9✔
824

825
    # Create system from callable, if needed
826
    sys = _create_system(sys, params)
9✔
827

828
    # Parse the arrows keyword
829
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
830

831
    # Determine the initial states to use in searching for equilibrium points
832
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
833
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
834

835
    # Find the equilibrium points
836
    equilpts = _find_equilpts(sys, points, params=params)
9✔
837
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
838

839
    # Create axis if needed
840
    if ax is None:
9✔
841
        ax = plt.gca()
9✔
842

843
    # Set the axis limits
844
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
845

846
    # Figure out the color to use for stable, unstable subspaces
847
    color = _get_color(kwargs)
9✔
848
    match color:
9✔
849
        case None:
9✔
850
            stable_color = 'r'
9✔
851
            unstable_color = 'b'
9✔
852
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
853
            pass
9✔
854
        case single_color:
9✔
855
            stable_color = unstable_color = single_color
9✔
856

857
    # Make sure all keyword arguments were processed
858
    if _check_kwargs and kwargs:
9✔
859
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
860

861
    # Create a "reverse time" system to use for simulation
862
    revsys = NonlinearIOSystem(
9✔
863
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
864
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
865
        outputs=sys.noutputs, params=sys.params)
866

867
    # Plot separatrices by flowing backwards in time along eigenspaces
868
    out = []
9✔
869
    for i, xeq in enumerate(equilpts):
9✔
870
        # Figure out the linearization and eigenvectors
871
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
872

873
        # See if we have real eigenvalues (=> evecs are meaningful)
874
        if evals[0].imag > 0:
9✔
875
            continue
9✔
876

877
        # Create default list of time points
878
        if timedata is not None:
9✔
879
            timepts = _make_timepts(timedata, i)
9✔
880

881
        # Generate the traces
882
        for j, dir in enumerate(evecs.T):
9✔
883
            # Figure out time vector if not yet computed
884
            if timedata is None:
9✔
885
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
886
                timepts = np.linspace(0, timescale)
9✔
887

888
            # Run the trajectory starting in eigenvector directions
889
            for eps in [-radius, radius]:
9✔
890
                x0 = xeq + dir * eps
9✔
891
                if evals[j].real < 0:
9✔
892
                    traj = _create_trajectory(
9✔
893
                        sys, revsys, timepts, x0, params, 'reverse',
894
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
895
                        suppress_warnings=suppress_warnings)
896
                    color = stable_color
9✔
897
                    linestyle = '--'
9✔
898
                elif evals[j].real > 0:
9✔
899
                    traj = _create_trajectory(
9✔
900
                        sys, revsys, timepts, x0, params, 'forward',
901
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
902
                        suppress_warnings=suppress_warnings)
903
                    color = unstable_color
9✔
904
                    linestyle = '-'
9✔
905

906
                # Plot the trajectory (if there is one)
907
                if traj.shape[1] > 1:
9✔
908
                    with plt.rc_context(rcParams):
9✔
909
                        out += ax.plot(
9✔
910
                            traj[0], traj[1], color=color,
911
                            linestyle=linestyle, zorder=zorder)
912

913
                    # Add arrows to the lines at specified intervals
914
                    with plt.rc_context(rcParams):
9✔
915
                        _add_arrows_to_line2D(
9✔
916
                            ax, out[-1], arrow_pos, arrowstyle=arrow_style,
917
                            dir=1)
918
    return out
9✔
919

920

921
#
922
# User accessible utility functions
923
#
924

925
# Utility function to generate boxgrid (in the form needed here)
926
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    points = boxgrid(xvals, yvals) generates a list of points that walks
    once around the perimeter of the box spanned by the x and y values:
    along the lower edge, up the right edge, back along the upper edge and
    down the left edge.  Each corner appears exactly once.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge.

    Examples
    --------
    >>> boxgrid([0, 1, 2], [0, 1]).tolist()
    [[0, 0], [1, 0], [2, 0], [2, 1], [1, 1], [0, 1]]

    """
    perimeter = (
        [(x, yvals[0]) for x in xvals[:-1]] +           # lower edge
        [(xvals[-1], y) for y in yvals[:-1]] +          # right edge
        [(x, yvals[-1]) for x in xvals[:0:-1]] +        # upper edge
        [(xvals[0], y) for y in yvals[:0:-1]]           # left edge
    )

    # reshape keeps the documented (p, 2) shape even when the box is
    # degenerate and no perimeter points are generated (p == 0)
    return np.array(perimeter).reshape(-1, 2)
951

952

953
# Utility function to generate meshgrid (in the form needed here)
954
# TODO: add examples of using grid functions directly
955
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    points = meshgrid(xvals, yvals) generates a list of points that
    corresponds to a grid given by the cross product of the x and y values.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array of points with shape (n * m, 2) defining the mesh.  The x
        coordinate varies fastest; the result is always floating point.

    """
    xmesh, ymesh = np.meshgrid(xvals, yvals)

    # Flatten both coordinate arrays (row-major) and pair them column-wise
    return np.column_stack([xmesh.ravel(), ymesh.ravel()]).astype(float)
9✔
979

980

981
# Utility function to generate circular grid
982
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    points = circlegrid(centers, radius, num) generates a list of points
    that form a circle around a list of centers.

    Parameters
    ----------
    centers : 2D array_like
        Array of points with shape (p, 2) defining centers of the circles.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid : 2D array
        Array of points with shape (p * num, 2) defining the circles.

    """
    center_array = np.atleast_2d(np.array(centers))

    # Offsets of the circle points relative to a center (same for all centers)
    angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
    ring = np.array([
        [radius * math.cos(angle), radius * math.sin(angle)]
        for angle in angles])

    # Fill in one block of `num` rows per center
    grid = np.zeros((center_array.shape[0] * num, 2))
    for k, center in enumerate(center_array):
        start = k * num
        grid[start: start + num, :] = center + ring
    return grid
9✔
1010

1011

1012
#
1013
# Internal utility functions
1014
#
1015

1016
# Create a system from a callable
1017
def _create_system(sys, params):
    """Return a planar I/O system for `sys`.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        An existing I/O system is returned unchanged (after checking that
        it has two states).  A callable is wrapped in a new
        `NonlinearIOSystem` with two states, no inputs and no outputs.
    params : dict or None
        For a callable, extra positional arguments are taken from
        ``params['args']``.

    Returns
    -------
    `NonlinearIOSystem`

    Raises
    ------
    ValueError
        If the I/O system does not have exactly two states, or if
        `params` is given for a callable but is not a dict that provides
        an 'args' entry.

    """
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates != 2:
            raise ValueError("system must be planar")
        return sys

    # Make sure that if params is present, it has 'args' key.  Test for the
    # entry itself (not its truthiness) so that an empty tuple of arguments
    # is accepted, and reject non-dict values with the same error.
    if params and (
            not isinstance(params, dict) or params.get('args') is None):
        raise ValueError("params must be dict with key 'args'")

    def _update(t, x, u, params):
        # Forward the extra arguments stored under 'args' to the callable
        return sys(t, x, *params.get('args', ()))

    def _output(t, x, u, params):
        # The wrapped system has no outputs
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
1031

1032

1033
# Set axis limits for the plot
1034
def _set_axis_limits(ax, pointdata):
    """Expand the limits of `ax` to include `pointdata` and apply them.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axes whose limits are updated.  If the axes already contain lines,
        the current limits are kept as a starting point; otherwise the new
        limits are determined by `pointdata` alone.
    pointdata : list, tuple, or 2D array
        Either a box [xmin, xmax, ymin, ymax] or an array of points whose
        first two columns are the x and y coordinates.

    Returns
    -------
    xlim, ylim : list of float
        The new axis limits, as [min, max].
    maxlim : float
        The larger of the x and y extents.

    """
    # Get the current axis limits
    if ax.lines:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
    else:
        # Nothing on the plot => always use new limits
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]

    # Short utility function for updating axis limits
    def _update_limits(cur, new):
        return [min(cur[0], np.min(new)), max(cur[1], np.max(new))]

    # If we were passed a box, use that to update the limits (a tuple is
    # accepted as well as a list so the limits are never left infinite)
    if isinstance(pointdata, (list, tuple)) and len(pointdata) == 4:
        xlim = _update_limits(xlim, [pointdata[0], pointdata[1]])
        ylim = _update_limits(ylim, [pointdata[2], pointdata[3]])

    elif isinstance(pointdata, np.ndarray):
        pointdata = np.atleast_2d(pointdata)
        xlim = _update_limits(
            xlim, [np.min(pointdata[:, 0]), np.max(pointdata[:, 0])])
        ylim = _update_limits(
            ylim, [np.min(pointdata[:, 1]), np.max(pointdata[:, 1])])

    # Keep track of the largest dimension on the plot
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])

    # Set the new limits
    ax.autoscale(enable=True, axis='x', tight=True)
    ax.autoscale(enable=True, axis='y', tight=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return xlim, ylim, maxlim
9✔
1068

1069

1070
# Find equilibrium points
1071
def _find_equilpts(sys, points, params=None):
    """Search for distinct equilibrium points starting from `points`.

    Each entry of `points` is used as an initial guess for
    `find_operating_point` (with zero input).  Guesses for which no
    point is found are ignored, and a result is only recorded if it is
    not numerically close to one that has already been recorded.

    Returns a list of equilibrium points, each stored as a plain list.

    """
    found = []
    for guess in points:
        # Look for an equilibrium point near this guess
        xeq, _ = find_operating_point(sys, guess, 0, params=params)
        if xeq is None:
            continue            # didn't find anything

        # Skip points that duplicate one we have already recorded
        if any(np.allclose(np.array(known), xeq) for known in found):
            continue

        found.append(xeq.tolist())

    return found
9✔
1092

1093

1094
def _make_points(pointdata, gridspec, gridtype):
9✔
1095
    # Check to see what type of data we got
1096
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
1097
        pointdata = np.atleast_2d(pointdata)
9✔
1098
        if pointdata.shape[1] == 2:
9✔
1099
            # Given a list of points => no action required
1100
            return pointdata, None
9✔
1101

1102
    # Utility function to parse (and check) input arguments
1103
    def _parse_args(defsize):
9✔
1104
        if gridspec is None:
9✔
1105
            return defsize
9✔
1106

1107
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
1108
             len(gridspec) != len(defsize):
1109
            raise ValueError("invalid grid specification")
9✔
1110

1111
        return gridspec
9✔
1112

1113
    # Generate points based on grid type
1114
    match gridtype:
9✔
1115
        case 'boxgrid' | None:
9✔
1116
            gridspec = _parse_args([6, 4])
9✔
1117
            points = boxgrid(
9✔
1118
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1119
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1120

1121
        case 'meshgrid':
9✔
1122
            gridspec = _parse_args([9, 6])
9✔
1123
            points = meshgrid(
9✔
1124
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
1125
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
1126

1127
        case 'circlegrid':
9✔
1128
            gridspec = _parse_args((0.5, 10))
9✔
1129
            if isinstance(pointdata, np.ndarray):
9✔
1130
                # Create circles around each point
1131
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
1132
            else:
1133
                # Create circle around center of the plot
1134
                points = circlegrid(
9✔
1135
                    np.array(
1136
                        [(pointdata[0] + pointdata[1]) / 2,
1137
                         (pointdata[0] + pointdata[1]) / 2]),
1138
                    gridspec[0], gridspec[1])
1139

1140
        case _:
9✔
1141
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
1142

1143
    return points, gridspec
9✔
1144

1145

1146
def _parse_arrow_keywords(kwargs):
    """Extract the arrow settings from `kwargs`.

    The 'arrows', 'arrow_size' and 'arrow_style' entries are removed from
    `kwargs` (falling back to the 'phaseplot' configuration values) so
    that they are not reported later as unrecognized keywords.

    Parameters
    ----------
    kwargs : dict
        Keyword dictionary; modified in place.

    Returns
    -------
    arrow_pos : list or 1D array
        Positions of the arrows along each curve.  An integer N for
        'arrows' gives N evenly spaced positions; a list or array gives
        the (sorted) positions directly; a false value gives no arrows.
    arrow_style : `matplotlib.patches.ArrowStyle`
        Style for the arrows; if not given, a 'simple' style sized from
        'arrow_size' is used.

    Raises
    ------
    ValueError
        If 'arrows' is not an integer, list, or array.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    # Pop the style as well, consistent with the two keywords above
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # An array with several entries has no truth value, so it must bypass
    # the "no arrows" test and go straight to the list/array branch
    is_multi_array = isinstance(arrows, np.ndarray) and arrows.size > 1

    # Parse the arrows keyword
    if not is_multi_array and not arrows:
        arrow_pos = []
    elif isinstance(arrows, int):
        N = arrows
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    elif isinstance(arrows, (list, np.ndarray)):
        arrow_pos = np.sort(np.atleast_1d(arrows))
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
9✔
1174

1175

1176
# TODO: move to ctrlplot?
1177
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory through `X0` and clip it to the plot window.

    Parameters
    ----------
    sys, revsys : I/O systems
        Forward-time system and its reverse-time counterpart.  `revsys` is
        only used when `dir` is 'reverse' or 'both'.
    timepts : array
        Time points passed to `input_output_response`.
    X0 : array_like
        Initial state.
    params : dict or None
        Parameters passed to `input_output_response`.
    dir : {'forward', 'reverse', 'both'}
        Direction(s) in which to simulate.  For 'reverse' the states are
        returned in reversed order (ending at `X0`); for 'both' the reversed
        reverse-time states are followed by the forward-time states.
    suppress_warnings : bool, optional
        If True, do not warn when a simulation reports failure.
    gridtype, gridspec : optional
        Accepted for call compatibility; not used by this function.
    xlim, ylim : pair of float
        Window limits.  Points outside the window are removed, except for
        the immediate neighbors of points inside the window.

    Returns
    -------
    traj : 2D array
        States along the trajectory, one column per retained point.

    Raises
    ------
    ValueError
        If `dir` is not one of the supported directions.

    """
    if dir not in ('forward', 'reverse', 'both'):
        # Fail early with a clear message instead of an unbound `traj` below
        raise ValueError(f"unknown direction '{dir}'")

    # Compute the forward trajectory
    if dir == 'forward' or dir == 'both':
        fwdresp = input_output_response(
            sys, timepts, initial_state=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"initial_state={X0}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir == 'reverse' or dir == 'both':
        revresp = input_output_response(
            revsys, timepts, initial_state=X0, params=params,
            ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"initial_state={X0}, {revresp.message}")

    # Create the trace to plot
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Both responses start at X0: drop only that shared first sample
        # from the reversed half so no other simulated point is lost
        traj = np.hstack([revresp.states[:, :0:-1], fwdresp.states])

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.asarray(
        (traj[0] >= xlim[0]) & (traj[0] <= xlim[1]) &
        (traj[1] >= ylim[0]) & (traj[1] <= ylim[1]))
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
9✔
1211

1212

1213
def _make_timepts(timepts, i):
9✔
1214
    if timepts is None:
9✔
1215
        return np.linspace(0, 1)
9✔
1216
    elif isinstance(timepts, (int, float)):
9✔
1217
        return np.linspace(0, timepts)
9✔
1218
    elif timepts.ndim == 2:
×
1219
        return timepts[i]
×
1220
    return timepts
×
1221

1222

1223
#
1224
# Legacy phase plot function
1225
#
1226
# Author: Richard Murray
1227
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1228
# a version by Kristi Morgansen
1229
#
1230
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):

    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use `phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by the `phase_plane_plot` function.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - plot both
      phase_plot(func, X0=[...], logtime=(N, lambda), ...) - stream lines
        with arrows

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Computes the time derivative of the state (compatible with
        `scipy.integrate.odeint`).  It should accept a state x of
        dimension 2 and return a derivative dx/dt of dimension 2.  By
        default it is called as ``odefun(x, t, *params)``; see `tfirst`.
    X, Y: 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.
    scale: float, optional
        Scale size of arrows; default = 1.  If None, unscaled arrows are
        drawn and arrows outside of the current axis limits are
        suppressed.  If negative, dots are added at the base of the
        arrows drawn on the stream lines.
    X0: ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T: array_like or number, optional
        Length of time to run simulations that generate streamlines.
        If a single number, 100 equally spaced time points on [0, T] are
        used for all initial conditions.  Otherwise, `T` is passed to the
        integrator as the vector of time points.  Default value = 50.
    lingrid : integer, optional
        Draw this many arrows along each stream line, at time indices
        that are equally spaced over the simulation.
    lintime : optional
        Not used by this implementation.
    logtime : tuple (integer, float), optional
        If a tuple (N, lambda) is given, draw N arrows along each stream
        line, with arrow j placed at the last simulated time before
        j / lambda.
    timepts : array_like, optional
        Draw arrows at the given list times [t1, t2, ...]
    parms : tuple, optional
        Deprecated alias for `params`.
    params: tuple, optional
        List of parameters to pass to vector field: ``func(x, t, *params)``.
    tfirst : bool, optional
        If True, call `func` with signature ``func(t, x, ...)``.
    verbose : bool, optional
        If True (default), print a message when `lingrid` or `logtime`
        arrows are used.

    Raises
    ------
    ControlArgument
        If both `parms` and `params` are given.

    See Also
    --------
    box_grid

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            "keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument("duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if (verbose):
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if (verbose):
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # Arrows are placed along the stream lines only for these options
    streamline_arrows = autoFlag or logtimeFlag or timeptsFlag

    if not streamline_arrows and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i,j], x2[i,j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i,j], x2[i,j]], 0, *params))

        # Plot the quiver plot
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            # The last axis of dx has length 2: components 0 (x) and 1 (y)
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif (scale != 0):
            plt.quiver(x1, x2, dx[:,:,0]*np.abs(scale),
                       dx[:,:,1]*np.abs(scale), angles='xy')
            #! TODO: optimize parameters for arrows
            #! TODO: figure out arguments to make arrows show up correctly
            # xy = plt.quiver(...)
            # set(xy, 'LineWidth', PP_arrow_linewidth, 'Color', 'b')

        #! TODO: Tweak the shape of the plot
        # a=gca; set(a,'DataAspectRatio',[1,1,1])
        # set(a,'XLim',X(1:2)); set(a,'YLim',Y(1:2))
        plt.xlabel('x1'); plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    (nr, nc) = np.shape(X0)

    # Generate some empty matrices to keep arrow information
    # NOTE(review): entries skipped in the loop below (tailless arrows at
    # the first time index) are left uninitialized -- confirm intended
    x1 = np.empty((nr, Narrows))
    x2 = np.empty((nr, Narrows))
    dx = np.empty((nr, Narrows, 2))

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed
    TSPAN = T
    if isinstance(T, (int, float)):
        TSPAN = np.linspace(0, T, 100)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin = alim[0]; xmax = alim[1]
        ymin = alim[2]; ymax = alim[3]
    else:
        # Use the maximum extent of all trajectories
        xmin = np.min(X0[:,0]); xmax = np.max(X0[:,0])
        ymin = np.min(X0[:,1]); ymax = np.max(X0[:,1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = TSPAN

        plt.plot(state[:,0], state[:,1])
        #! TODO: add back in colors for stream lines
        # PP_stream_color(np.mod(i-1, len(PP_stream_color))+1))
        # set(h[i], 'LineWidth', PP_stream_linewidth)

        # Plot arrows if quiver parameters were 'auto'
        if streamline_arrows:
            # Compute the locations of the arrows
            #! TODO: check this logic to make sure it works in python
            for j in range(Narrows):

                # Figure out starting index; headless arrows start at 0
                k = -1 if scale is None else 0

                # Figure out what time index to use for the next point
                if autoFlag:
                    # Use a linear scaling based on ODE time vector
                    # (convert to int: a float cannot be used as an index)
                    tind = int(np.floor((len(time)/Narrows) * (j-k)) + k)
                elif logtimeFlag:
                    # Use an exponential time vector
                    # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                    tarr = _find(time < (j-k) / timefactor)
                    tind = tarr[-1] if len(tarr) else 0
                elif timeptsFlag:
                    # Use specified time points
                    # MATLAB: tind = find(time < Y[j], 1, 'last')
                    tarr = _find(time < timepts[j])
                    tind = tarr[-1] if len(tarr) else 0

                # For tailless arrows, skip the first point
                if tind == 0 and scale is None:
                    continue

                # Figure out the arrow at this point on the curve
                x1[i,j] = state[tind, 0]
                x2[i,j] = state[tind, 1]

                # Skip arrows outside of initial condition box
                if (scale is not None or
                     (x1[i,j] <= xmax and x1[i,j] >= xmin and
                      x2[i,j] <= ymax and x2[i,j] >= ymin)):
                    if tfirst:
                        v = odefun(0, [x1[i,j], x2[i,j]], *params)
                    else:
                        v = odefun([x1[i,j], x2[i,j]], 0, *params)
                    dx[i, j, 0] = v[0]; dx[i, j, 1] = v[1]
                else:
                    dx[i, j, 0] = 0; dx[i, j, 1] = 0

    # Set the plot shape before plotting arrows to avoid warping
    # a=gca
    # if (scale != None):
    #     set(a,'DataAspectRatio', [1,1,1])
    # if (xmin != xmax and ymin != ymax):
    #     plt.axis([xmin, xmax, ymin, ymax])
    # set(a, 'Box', 'on')

    # The arrow arrays are only filled in when arrows were requested along
    # the stream lines; otherwise there is nothing meaningful to draw
    if not streamline_arrows:
        return

    # Plot arrows on the streamlines
    if scale is None and Narrows > 0:
        # Use a tailless arrow
        #! TODO: figure out arguments to make arrows show up correctly
        plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
    elif scale != 0 and Narrows > 0:
        plt.quiver(x1, x2, dx[:,:,0]*abs(scale), dx[:,:,1]*abs(scale),
                   angles='xy')
        #! TODO: figure out arguments to make arrows show up correctly
        # xy = plt.quiver(...)
        # set(xy, 'LineWidth', PP_arrow_linewidth)
        # set(xy, 'AutoScale', 'off')
        # set(xy, 'AutoScaleFactor', 0)

    # scale may be None, which cannot be compared with a number
    if scale is not None and scale < 0:
        plt.plot(x1, x2, 'b.');        # add dots at base
        # bp = plt.plot(...)
        # set(bp, 'MarkerSize', PP_arrow_markersize)
1494

1495

1496
# Utility function for generating initial conditions around a box
1497
def box_grid(xlimp, ylimp):
    """Generate list of points on edge of box.

    .. deprecated:: 0.10.0
        Use `phaseplot.boxgrid` instead.

    list = box_grid([xmin xmax xnum], [ymin ymax ynum]) generates a
    list of points that correspond to a uniform grid along the edge of
    the box defined by the corners [xmin ymin] and [xmax ymax].

    """
    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    # Expand each [min, max, num] specification into evenly spaced values
    xvals = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    yvals = np.linspace(ylimp[0], ylimp[1], ylimp[2])

    return boxgrid(xvals, yvals)
1517

1518

1519
# TODO: rename to something more useful (or remove??)
1520
def _find(condition):
9✔
1521
    """Returns indices where ravel(a) is true.
1522

1523
    Private implementation of deprecated `matplotlib.mlab.find`.
1524

1525
    """
1526
    return np.nonzero(np.ravel(condition))[0]
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc