• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 13222870129

09 Feb 2025 05:48AM UTC coverage: 94.715% (-0.04%) from 94.751%
13222870129

Pull #1125

github

web-flow
Merge fc9fadb37 into cf77f990b
Pull Request #1125: Uniform processing of time response and optimization parameters

9804 of 10351 relevant lines covered (94.72%)

8.28 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.42
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Initial author: Richard M. Murray
4
# Creation date: 24 July 2011, converted from MATLAB version (2002);
5
# based on an original version by Kristi Morgansen
6

7
"""Generate 2D phase portraits.
8

9
This module contains functions for generating 2D phase plots. The base
10
function for creating phase plane portraits is `~control.phase_plane_plot`,
11
which generates a phase plane portrait for a 2 state I/O system (with no
12
inputs). Utility functions are available to customize the individual
13
elements of a phase plane portrait.
14

15
The docstring examples assume the following import commands::
16

17
  >>> import numpy as np
18
  >>> import control as ct
19
  >>> import control.phaseplot as pp
20

21
"""
22

23
import math
9✔
24
import warnings
9✔
25

26
import matplotlib as mpl
9✔
27
import matplotlib.pyplot as plt
9✔
28
import numpy as np
9✔
29
from scipy.integrate import odeint
9✔
30

31
from . import config
9✔
32
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
33
    _process_ax_keyword, _update_plot_title
34
from .exception import ControlArgument
9✔
35
from .nlsys import NonlinearIOSystem, find_operating_point, \
9✔
36
    input_output_response
37

38
__all__ = ['phase_plane_plot', 'phase_plot', 'box_grid']
9✔
39

40
# Default values for module parameter variables
41
# Module-level defaults for phase plot parameters.  Keys use the
# 'phaseplot.<name>' form and are looked up elsewhere in this module via
# config._get_param('phaseplot', '<name>').
_phaseplot_defaults = {
    # number of arrows around curve
    'phaseplot.arrows': 2,
    # pixel size for arrows
    'phaseplot.arrow_size': 8,
    # set arrow style
    'phaseplot.arrow_style': None,
    # initial radius for separatrices
    'phaseplot.separatrices_radius': 0.1,
}
47

48
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
        plot_separatrices=True, ax=None, suppress_warnings=False, title=None,
        **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  If not given,
        [-1, 1, -1, 1] is used.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use ``plot_<element>`` =
        {'color': c} to set the color in one element of the phase
        plot (equilpoints, separatrices, streamlines, etc)).
    ax : `matplotlib.axes.Axes`, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : `ControlPlot` object
        Object containing the data that were plotted.  See `ControlPlot`
        for more detailed information.
    cplt.lines : array of list of `matplotlib.lines.Line2D`
        Array of list of `matplotlib.artist.Artist` objects:

            - lines[0] = list of Line2D objects (streamlines, separatrices).
            - lines[1] = Quiver object (vector field arrows).
            - lines[2] = list of Line2D objects (equilibrium points).

    cplt.axes : 2D array of `matplotlib.axes.Axes`
        Axes for each subplot.
    cplt.figure : `matplotlib.figure.Figure`
        Figure containing the plot.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
    plot_streamlines : bool or dict, optional
        If True (default) then plot streamlines based on the pointdata
        and gridtype.  If set to a dict, pass on the key-value pairs in
        the dict as keywords to `streamlines`.
    plot_vectorfield : bool or dict, optional
        If True then plot the vector field based on the pointdata
        and gridtype (default is False).  If set to a dict, pass on the
        key-value pairs in the dict as keywords to `phaseplot.vectorfield`.
    plot_equilpoints : bool or dict, optional
        If True (default) then plot equilibrium points based in the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to `phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If True (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to `phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating
        trajectories (streamlines and separatrices).
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    Raises
    ------
    TypeError
        If a keyword argument is not consumed by any of the plot elements
        that were enabled.

    """
    # Process arguments
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Keep a copy of the keywords so that unused ones can be reported below
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments.  Precedence (lowest to
    # highest): shared keywords, keywords set by this function, and the
    # per-element dict passed via plot_<element>.
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Create list for storing outputs: [lines, quiver, equilibrium points]
    out = np.array([[], None, None], dtype=object)

    # Plot out the main elements
    if plot_streamlines:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax, suppress_warnings=suppress_warnings)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed (the other
    # elements always use a mesh, so a circlegrid-style gridspec is dropped)
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_separatrices:
        # suppress_warnings is forwarded here as well, since separatrices
        # also generates trajectories
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax,
            suppress_warnings=suppress_warnings)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError(f"unrecognized keywords: {initial_kwargs}")

    # Only decorate the axes when they were created by this function
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
240

241

242
def vectorfield(
        sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.  The field is evaluated at time 0 with all inputs set
    to zero.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate points.
        The points are always generated using the 'meshgrid' grid type.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver

    Raises
    ------
    TypeError
        If unrecognized keywords remain and `_check_kwargs` is True.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions;
        not used by this function.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Generate phase plane (quiver) data: columns are [x1, x2, dx1, dx2]
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    zero_input = np.zeros(sys.ninputs)      # same input for every point
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, zero_input)

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color)

    return out
333

334

335
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):
    """Plot stream lines in the phase plane.

    This function plots stream lines for a two-dimensional state space
    system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the `timedata` argument.
        If not given, 'both' is used when `gridtype` is 'meshgrid' and
        'forward' is used otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If unrecognized keywords remain and `_check_kwargs` is True.

    Other Parameters
    ----------------
    arrows : int
        Set the number of arrows to plot along the streamlines. The default
        value can be set in `config.defaults['phaseplot.arrows']`.
    arrow_size : float
        Set the size of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_size']`.
    arrow_style : matplotlib patch
        Set the style of arrows to plot along the streamlines.  The default
        value can be set in `config.defaults['phaseplot.arrow_style']`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.
    suppress_warnings : bool, optional
        If set to True, suppress warning messages in generating trajectories.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Parse the arrows keyword
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Determine the points on which to generate the streamlines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Create reverse time system (negated update function), if needed
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)
    else:
        revsys = None

    # Generate phase plane (streamline) data
    out = []
    for i, X0 in enumerate(points):
        # Create the trajectory for this point
        timepts = _make_timepts(timedata, i)
        traj = _create_trajectory(
            sys, revsys, timepts, X0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Plot the trajectory (if there is one)
        if traj.shape[1] > 1:
            with plt.rc_context(rcParams):
                out += ax.plot(traj[0], traj[1], color=color)

                # Add arrows to the lines at specified intervals
                _add_arrows_to_line2D(
                    ax, out[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return out
463

464

465
def equilpoints(
        sys, pointdata, gridspec=None, color='k', ax=None,
        _check_kwargs=True, **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : `NonlinearIOSystem` or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate the
        points used in the search for equilibrium points.  The points are
        always generated using the 'meshgrid' grid type.  If not given,
        [5, 5] is used.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the equilibrium points in the given color.
    ax : `matplotlib.axes.Axes`, optional
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If unrecognized keywords remain and `_check_kwargs` is True.

    Other Parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by `config.defaults['ctrlplot.rcParams']`.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Determine the grid of points handed to the equilibrium point search
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points
    out = []
    for xeq in equilpts:
        with plt.rc_context(rcParams):
            out += ax.plot(xeq[0], xeq[1], marker='o', color=color)
    return out
547

548

549
def separatrices(
9✔
550
        sys, pointdata, timedata=None, gridspec=None, ax=None,
551
        _check_kwargs=True, suppress_warnings=False, **kwargs):
552
    """Plot separatrices in the phase plane.
553

554
    This function plots separatrices for a two-dimensional state space
555
    system.
556

557
    Parameters
558
    ----------
559
    sys : `NonlinearIOSystem` or callable(t, x, ...)
560
        I/O system or function used to generate phase plane data. If a
561
        function is given, the remaining arguments are drawn from the
562
        `params` keyword.
563
    pointdata : list or 2D array
564
        List of the form [xmin, xmax, ymin, ymax] describing the
565
        boundaries of the phase plot or an array of shape (N, 2)
566
        giving points of at which to plot the vector field.
567
    timedata : int or list of int
568
        Time to simulate each streamline.  If a list is given, a different
569
        time can be used for each initial condition in `pointdata`.
570
    gridtype : str, optional
571
        The type of grid to use for generating initial conditions:
572
        'meshgrid' (default) generates a mesh of initial conditions within
573
        the specified boundaries, 'boxgrid' generates initial conditions
574
        along the edges of the boundary, 'circlegrid' generates a circle of
575
        initial conditions around each point in point data.
576
    gridspec : list, optional
577
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
578
        size of the grid in the x and y axes on which to generate points.
579
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
580
        specifying the radius and number of points around each point in the
581
        `pointdata` array.
582
    params : dict or list, optional
583
        Parameters to pass to system. For an I/O system, `params` should be
584
        a dict of parameters and values. For a callable, `params` should be
585
        dict with key 'args' and value given by a tuple (passed to callable).
586
    color : matplotlib color spec, optional
587
        Plot the separatrices in the given color.  If a single color
588
        specification is given, this is used for both stable and unstable
589
        separatrices.  If a tuple is given, the first element is used as
590
        the color specification for stable separatrices and the second
591
        element for unstable separatrices.
592
    ax : `matplotlib.axes.Axes`, optional
593
        Use the given axes for the plot, otherwise use the current axes.
594

595
    Returns
596
    -------
597
    out : list of Line2D objects
598

599
    Other Parameters
600
    ----------------
601
    rcParams : dict
602
        Override the default parameters used for generating plots.
603
        Default is set by `config.defaults['ctrlplot.rcParams']`.
604
    suppress_warnings : bool, optional
605
        If set to True, suppress warning messages in generating trajectories.
606

607
    Notes
608
    -----
609
    The value of `config.defaults['separatrices_radius']` is used to set the
610
    offset from the equilibrium point to the starting point of the separatix
611
    traces, in the direction of the eigenvectors evaluated at that
612
    equilibrium point.
613

614
    """
615
    # Process keywords
616
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
617

618
    # Get system parameters
619
    params = kwargs.pop('params', None)
9✔
620

621
    # Create system from callable, if needed
622
    sys = _create_system(sys, params)
9✔
623

624
    # Parse the arrows keyword
625
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
626

627
    # Determine the initial states to use in searching for equilibrium points
628
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
629
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
630

631
    # Find the equilibrium points
632
    equilpts = _find_equilpts(sys, points, params=params)
9✔
633
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
634

635
    # Create axis if needed
636
    if ax is None:
9✔
637
        ax = plt.gca()
9✔
638

639
    # Set the axis limits
640
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
641

642
    # Figure out the color to use for stable, unstable subspaces
643
    color = _get_color(kwargs)
9✔
644
    match color:
9✔
645
        case None:
9✔
646
            stable_color = 'r'
9✔
647
            unstable_color = 'b'
9✔
648
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
649
            pass
9✔
650
        case single_color:
9✔
651
            stable_color = unstable_color = single_color
9✔
652

653
    # Make sure all keyword arguments were processed
654
    if _check_kwargs and kwargs:
9✔
655
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
656

657
    # Create a "reverse time" system to use for simulation
658
    revsys = NonlinearIOSystem(
9✔
659
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
660
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
661
        outputs=sys.noutputs, params=sys.params)
662

663
    # Plot separatrices by flowing backwards in time along eigenspaces
664
    out = []
9✔
665
    for i, xeq in enumerate(equilpts):
9✔
666
        # Plot the equilibrium points
667
        with plt.rc_context(rcParams):
9✔
668
            out += ax.plot(xeq[0], xeq[1], marker='o', color='k')
9✔
669

670
        # Figure out the linearization and eigenvectors
671
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
672

673
        # See if we have real eigenvalues (=> evecs are meaningful)
674
        if evals[0].imag > 0:
9✔
675
            continue
9✔
676

677
        # Create default list of time points
678
        if timedata is not None:
9✔
679
            timepts = _make_timepts(timedata, i)
9✔
680

681
        # Generate the traces
682
        for j, dir in enumerate(evecs.T):
9✔
683
            # Figure out time vector if not yet computed
684
            if timedata is None:
9✔
685
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
686
                timepts = np.linspace(0, timescale)
9✔
687

688
            # Run the trajectory starting in eigenvector directions
689
            for eps in [-radius, radius]:
9✔
690
                x0 = xeq + dir * eps
9✔
691
                if evals[j].real < 0:
9✔
692
                    traj = _create_trajectory(
9✔
693
                        sys, revsys, timepts, x0, params, 'reverse',
694
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
695
                        suppress_warnings=suppress_warnings)
696
                    color = stable_color
9✔
697
                    linestyle = '--'
9✔
698
                elif evals[j].real > 0:
9✔
699
                    traj = _create_trajectory(
9✔
700
                        sys, revsys, timepts, x0, params, 'forward',
701
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
702
                        suppress_warnings=suppress_warnings)
703
                    color = unstable_color
9✔
704
                    linestyle = '-'
9✔
705

706
                # Plot the trajectory (if there is one)
707
                if traj.shape[1] > 1:
9✔
708
                    with plt.rc_context(rcParams):
9✔
709
                        out += ax.plot(
9✔
710
                            traj[0], traj[1], color=color, linestyle=linestyle)
711

712
                    # Add arrows to the lines at specified intervals
713
                    with plt.rc_context(rcParams):
9✔
714
                        _add_arrows_to_line2D(
9✔
715
                            ax, out[-1], arrow_pos, arrowstyle=arrow_style,
716
                            dir=1)
717
    return out
9✔
718

719

720
#
721
# User accessible utility functions
722
#
723

724
# Utility function to generate boxgrid (in the form needed here)
725
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    points = boxgrid(xvals, yvals) generates a list of points lying on the
    perimeter of the box whose edges are sampled at the given x and y
    values.  The points are ordered along the lower edge, then the right
    edge, then the upper edge (in reverse x order), and finally the left
    edge (in reverse y order); each corner appears exactly once.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge.

    """
    perimeter = []

    # Each edge omits its final corner, which starts the next edge
    perimeter.extend((x, yvals[0]) for x in xvals[:-1])         # lower
    perimeter.extend((xvals[-1], y) for y in yvals[:-1])        # right
    perimeter.extend((x, yvals[-1]) for x in xvals[:0:-1])      # upper
    perimeter.extend((xvals[0], y) for y in yvals[:0:-1])       # left

    return np.array(perimeter)
750

751

752
# Utility function to generate meshgrid (in the form needed here)
753
# TODO: add examples of using grid functions directly
754
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    points = meshgrid(xvals, yvals) generates a list of points that
    corresponds to a grid given by the cross product of the x and y values.

    Parameters
    ----------
    xvals, yvals : 1D array_like
        Coordinates of the mesh points along the x and y axes.

    Returns
    -------
    grid : 2D array
        Array of points with shape (n * m, 2) defining the mesh.  The x
        coordinate varies fastest and the result is always floating point.

    """
    xx, yy = np.meshgrid(xvals, yvals)

    # Flatten both coordinate arrays in the same (row-major) order and
    # pair them up column-wise
    return np.column_stack((xx.ravel(), yy.ravel())).astype(float)
778

779

780
# Utility function to generate circular grid
781
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    points = circlegrid(centers, radius, num) generates a list of points
    that form a circle around a list of centers.

    Parameters
    ----------
    centers : 2D array_like
        Array of points with shape (p, 2) defining centers of the circles.
        A single point of shape (2,) is treated as one center.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid : 2D array
        Array of points with shape (p * num, 2) defining the circles.

    """
    center_arr = np.atleast_2d(np.array(centers))
    points = np.zeros((center_arr.shape[0] * num, 2))

    start = 0
    for center in center_arr:
        # Equally spaced angles; the end point (2 pi) is left out so the
        # first point is not repeated
        angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
        ring = np.array([
            [radius * math.cos(angle), radius * math.sin(angle)]
            for angle in angles])
        points[start:start + num, :] = center + ring
        start += num

    return points
809

810
#
811
# Internal utility functions
812
#
813

814
# Create a system from a callable
815
def _create_system(sys, params):
    """Return a planar I/O system built from a system object or callable.

    If `sys` is already a `NonlinearIOSystem` it is returned unchanged
    (after checking that it has exactly two states).  Otherwise `sys` is
    treated as a callable of the form ``sys(t, x, *args)`` and wrapped in a
    `NonlinearIOSystem` with two states, no inputs, and no outputs.

    Raises
    ------
    ValueError
        If `sys` is a `NonlinearIOSystem` that does not have two states, or
        if `params` is non-empty but has no (non-empty) 'args' entry.

    """
    if not isinstance(sys, NonlinearIOSystem):
        # Make sure that if params is present, it has 'args' key
        if params and not params.get('args', None):
            raise ValueError("params must be dict with key 'args'")

        # State update: forward any extra arguments stored under 'args'
        def _update(t, x, u, params):
            return sys(t, x, *params.get('args', ()))

        # The wrapped system has no outputs
        def _output(t, x, u, params):
            return np.array([])

        return NonlinearIOSystem(
            _update, _output, states=2, inputs=0, outputs=0,
            name="_callable")

    if sys.nstates != 2:
        raise ValueError("system must be planar")
    return sys
829

830
# Set axis limits for the plot
831
def _set_axis_limits(ax, pointdata):
    """Grow the axis limits of `ax` to include `pointdata`.

    `pointdata` may be a 4-element list, read as [xmin, xmax, ymin, ymax],
    or an ndarray whose first two columns are x and y coordinates.  Any
    other value leaves the starting limits unchanged.

    Returns the tuple (xlim, ylim, maxlim), where maxlim is the larger of
    the x and y extents.

    """
    # Start from the current limits only if something is already plotted;
    # otherwise use an "empty" interval so the new data always wins
    if ax.lines:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
    else:
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]

    # Work out which x and y values should be folded into the limits
    if isinstance(pointdata, list) and len(pointdata) == 4:
        xnew, ynew = pointdata[:2], pointdata[2:]
    elif isinstance(pointdata, np.ndarray):
        pts = np.atleast_2d(pointdata)
        xnew, ynew = pts[:, 0], pts[:, 1]
    else:
        xnew = ynew = None

    if xnew is not None:
        xlim = [min(xlim[0], np.min(xnew)), max(xlim[1], np.max(xnew))]
        ylim = [min(ylim[0], np.min(ynew)), max(ylim[1], np.max(ynew))]

    # Keep track of the largest dimension on the plot
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])

    # Set the new limits
    for axis_name in ('x', 'y'):
        ax.autoscale(enable=True, axis=axis_name, tight=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return xlim, ylim, maxlim
865

866

867
# Find equilibrium points
868
def _find_equilpts(sys, points, params=None):
    """Return the distinct equilibrium points found near `points`.

    For each initial guess in `points`, `find_operating_point` is called
    with zero input.  Guesses for which no point is returned are skipped,
    as are results that match (via `np.allclose`) a point already found.
    The result is a list of equilibrium points, each stored as a list.

    """
    found = []
    for guess in points:
        # Look for an equilibrium point near this point
        xeq, _ = find_operating_point(sys, guess, 0, params=params)
        if xeq is None:
            continue            # didn't find anything

        # Skip points that we have already found
        if any(np.allclose(np.array(prev), xeq) for prev in found):
            continue

        found.append(xeq.tolist())

    return found
889

890

891
def _make_points(pointdata, gridspec, gridtype):
9✔
892
    # Check to see what type of data we got
893
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
894
        pointdata = np.atleast_2d(pointdata)
9✔
895
        if pointdata.shape[1] == 2:
9✔
896
            # Given a list of points => no action required
897
            return pointdata, None
9✔
898

899
    # Utility function to parse (and check) input arguments
900
    def _parse_args(defsize):
9✔
901
        if gridspec is None:
9✔
902
            return defsize
9✔
903

904
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
905
             len(gridspec) != len(defsize):
906
            raise ValueError("invalid grid specification")
9✔
907

908
        return gridspec
9✔
909

910
    # Generate points based on grid type
911
    match gridtype:
9✔
912
        case 'boxgrid' | None:
9✔
913
            gridspec = _parse_args([6, 4])
9✔
914
            points = boxgrid(
9✔
915
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
916
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
917

918
        case 'meshgrid':
9✔
919
            gridspec = _parse_args([9, 6])
9✔
920
            points = meshgrid(
9✔
921
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
922
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
923

924
        case 'circlegrid':
9✔
925
            gridspec = _parse_args((0.5, 10))
9✔
926
            if isinstance(pointdata, np.ndarray):
9✔
927
                # Create circles around each point
928
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
929
            else:
930
                # Create circle around center of the plot
931
                points = circlegrid(
9✔
932
                    np.array(
933
                        [(pointdata[0] + pointdata[1]) / 2,
934
                         (pointdata[0] + pointdata[1]) / 2]),
935
                    gridspec[0], gridspec[1])
936

937
        case _:
9✔
938
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
939

940
    return points, gridspec
9✔
941

942

943
def _parse_arrow_keywords(kwargs):
    """Extract arrow positions and arrow style from keyword arguments.

    Reads the 'arrows', 'arrow_size', and 'arrow_style' parameters via
    `config._get_param` (the first two are removed from `kwargs`) and
    returns the tuple (arrow_pos, arrow_style).

    Raises
    ------
    ValueError
        If 'arrows' is not falsy, an int, a list, or an ndarray.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    # NOTE(review): unlike the two keywords above, 'arrow_style' is read
    # without pop=True, so it presumably stays in kwargs -- confirm that
    # this is intended.
    arrow_style = config._get_param('phaseplot', 'arrow_style', kwargs, None)

    # Parse the arrows keyword
    if not arrows:
        arrow_pos = []
    elif isinstance(arrows, int):
        # Space arrows out, starting midway along each "region"
        offset = 0.5 / arrows
        arrow_pos = np.linspace(offset, 1 + offset, arrows, endpoint=False)
    elif isinstance(arrows, (list, np.ndarray)):
        arrow_pos = np.sort(np.atleast_1d(arrows))
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Use the explicitly given arrow style, if there is one
    if arrow_style is not None:
        return arrow_pos, arrow_style

    # Otherwise build a simple arrow scaled by the arrow size
    default_style = mpl.patches.ArrowStyle(
        'simple', head_width=int(2 * arrow_size / 3),
        head_length=arrow_size)
    return arrow_pos, default_style
971

972

973
# TODO: move to ctrlplot?
974
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory from `X0` and clip it to the plot window.

    Parameters
    ----------
    sys, revsys : I/O systems
        Systems simulated for the 'forward' and 'reverse' directions,
        respectively (presumably `revsys` is the time-reversed version of
        `sys`; confirm against the caller).
    timepts : array_like
        Time points passed to `input_output_response`.
    X0 : array_like
        Initial state.
    params : dict or None
        Parameters passed to `input_output_response`.
    dir : {'forward', 'reverse', 'both'}
        Direction(s) in which to simulate.
    suppress_warnings : bool, optional
        If True, do not warn when a simulation reports failure.
    gridtype, gridspec : optional
        Accepted for call compatibility; not used here.
    xlim, ylim : 2-element sequences, optional
        Window limits.  Points outside the window are removed, except for
        the immediate neighbors of points inside the window.  A limit of
        None leaves the corresponding coordinate unbounded.

    Returns
    -------
    traj : 2D array
        State trajectory with one column per retained time point.  For
        'reverse' and 'both', the reverse response is flipped so that the
        columns run from the far end of the reverse response to `X0`.

    Raises
    ------
    ValueError
        If `dir` is not one of 'forward', 'reverse', or 'both'.

    """
    if dir not in ('forward', 'reverse', 'both'):
        raise ValueError(f"unknown trajectory direction '{dir}'")

    # Compute the forward trajectory
    if dir in ('forward', 'both'):
        fwdresp = input_output_response(
            sys, timepts, initial_state=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"initial_state={X0}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir in ('reverse', 'both'):
        revresp = input_output_response(
            revsys, timepts, initial_state=X0, params=params,
            ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"initial_state={X0}, {revresp.message}")

    # Create the trace to plot
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Both responses start at X0; drop only that duplicated first
        # sample of the reverse response before joining the two pieces
        traj = np.hstack([revresp.states[:, :0:-1], fwdresp.states])

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.ones(traj.shape[1], dtype=bool)
    if xlim is not None:
        inrange &= (traj[0] >= xlim[0]) & (traj[0] <= xlim[1])
    if ylim is not None:
        inrange &= (traj[1] >= ylim[0]) & (traj[1] <= ylim[1])
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
1008

1009

1010
def _make_timepts(timepts, i):
    """Return the vector of time points to use for trajectory `i`.

    Parameters
    ----------
    timepts : None, scalar, or array_like
        If None, use `np.linspace(0, 1)`.  If a scalar (including NumPy
        scalars), it is the final time and `np.linspace(0, timepts)` is
        used.  If 2D, row `i` is used.  Otherwise the data is returned as
        an array.
    i : int
        Index of the trajectory (only used for 2D `timepts`).

    Returns
    -------
    1D array
        Time points for the trajectory.

    """
    if timepts is None:
        return np.linspace(0, 1)

    # Accept lists/tuples as well as arrays (no copy is made for ndarrays)
    timearr = np.asarray(timepts)
    if timearr.ndim == 0:
        # Scalar => interpret as the final time
        return np.linspace(0, timearr.item())
    if timearr.ndim == 2:
        return timearr[i]
    return timearr
1018

1019

1020
#
1021
# Legacy phase plot function
1022
#
1023
# Author: Richard Murray
1024
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1025
# a version by Kristi Morgansen
1026
#
1027
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):
    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use `phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by the `phase_plane_map` and
    `phase_plane_plot` functions.

    Call signatures:
      phase_plot(odefun, X, Y, ...) - display vector field on meshgrid
      phase_plot(odefun, X, Y, scale, ...) - scale arrows
      phase_plot(odefun, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(odefun, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(odefun, X0=[...], T=Tmax, lingrid=N, ...) - plot both
      phase_plot(odefun, X0=[...], logtime=(N, lambda), ...) - stream
        lines with arrows

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Computes the time derivative of the state (compatible with
        `scipy.integrate.odeint`).  It is called as
        ``odefun(x, t, *params)``, or as ``odefun(t, x, *params)`` if
        `tfirst` is True, with a state x of dimension 2, and should return
        a derivative dx/dt of dimension 2.
    X, Y: 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.
    scale: float, optional
        Scale size of arrows; default = 1.  If None, arrows are drawn
        unscaled; if 0, no arrows are drawn; if negative, the absolute
        value is used and dots are added at the base of stream line
        arrows.
    X0: ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T: array_like or number, optional
        Length of time to run simulations that generate streamlines.
        If a single number, the simulation is run from 0 to T using 100
        time points.  Otherwise, T is used directly as the vector of time
        points for every initial condition.  Default value = 50.
    lingrid : integer, optional
        If given, draw this many arrows along each stream line, spaced
        linearly in the index of the simulation time vector.
    lintime : optional
        Accepted for backward compatibility; not used.
    logtime : tuple (integer, float), optional
        If a tuple (N, lambda) is given, draw N arrows along each stream
        line, with arrow j placed at the last time point before
        j / lambda.
    timepts : array_like, optional
        Draw arrows at the given list times [t1, t2, ...]
    parms : tuple, optional
        Deprecated alias for `params`.
    params: tuple, optional
        List of parameters to pass to vector field: ``func(x, t, *params)``.
    tfirst : bool, optional
        If True, call `odefun` with signature ``odefun(t, x, ...)``.
    verbose : bool, optional
        If True, print a message when `lingrid` or `logtime` arrows are
        used.

    Raises
    ------
    ControlArgument
        If both `parms` and `params` are given.

    See Also
    --------
    box_grid

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            "keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument("duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if verbose:
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if verbose:
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # Arrows along the stream lines are only computed in these modes
    stream_arrows = autoFlag or logtimeFlag or timeptsFlag

    if not stream_arrows and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i, j], x2[i, j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i, j], x2[i, j]], 0, *params))

        # Plot the quiver plot (the two components of dx are stored at
        # indices 0 and 1 of the last axis)
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
        elif scale != 0:
            plt.quiver(x1, x2, dx[:, :, 0]*np.abs(scale),
                       dx[:, :, 1]*np.abs(scale), angles='xy')
            #! TODO: optimize parameters for arrows

        #! TODO: Tweak the shape of the plot
        plt.xlabel('x1')
        plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    (nr, nc) = np.shape(X0)

    # Generate some empty matrices to keep arrow information
    x1 = np.empty((nr, Narrows))
    x2 = np.empty((nr, Narrows))
    dx = np.empty((nr, Narrows, 2))

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed
    if isinstance(T, (int, float)):
        TSPAN = np.linspace(0, T, 100)
    else:
        # Convert to an array so the time comparisons below are elementwise
        TSPAN = np.asarray(T)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin, xmax = alim[0], alim[1]
        ymin, ymax = alim[2], alim[3]
    else:
        # Use the maximum extent of all trajectories
        xmin, xmax = np.min(X0[:, 0]), np.max(X0[:, 0])
        ymin, ymax = np.min(X0[:, 1]), np.max(X0[:, 1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = TSPAN

        plt.plot(state[:, 0], state[:, 1])
        #! TODO: add back in colors for stream lines

        # Plot arrows if quiver parameters were 'auto'
        if stream_arrows:
            # Compute the locations of the arrows
            #! TODO: check this logic to make sure it works in python
            for j in range(Narrows):

                # Figure out starting index; headless arrows start at 0
                k = -1 if scale is None else 0

                # Figure out what time index to use for the next point
                if autoFlag:
                    # Use a linear scaling based on ODE time vector; the
                    # result is used as an array index, so make it an int
                    tind = int(np.floor((len(time)/Narrows) * (j-k))) + k
                elif logtimeFlag:
                    # Use an exponential time vector
                    # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                    tarr = _find(time < (j-k) / timefactor)
                    tind = tarr[-1] if len(tarr) else 0
                elif timeptsFlag:
                    # Use specified time points
                    # MATLAB: tind = find(time < Y[j], 1, 'last')
                    tarr = _find(time < timepts[j])
                    tind = tarr[-1] if len(tarr) else 0

                # For tailless arrows, skip the first point
                if tind == 0 and scale is None:
                    continue

                # Figure out the arrow at this point on the curve
                x1[i, j] = state[tind, 0]
                x2[i, j] = state[tind, 1]

                # Skip arrows outside of initial condition box
                if (scale is not None or
                     (x1[i, j] <= xmax and x1[i, j] >= xmin and
                      x2[i, j] <= ymax and x2[i, j] >= ymin)):
                    if tfirst:
                        v = odefun(0, [x1[i, j], x2[i, j]], *params)
                    else:
                        v = odefun([x1[i, j], x2[i, j]], 0, *params)
                    dx[i, j, 0] = v[0]
                    dx[i, j, 1] = v[1]
                else:
                    dx[i, j, 0] = 0
                    dx[i, j, 1] = 0

    # Plot arrows on the streamlines.  The arrow arrays are only filled in
    # when one of the arrow modes is active; otherwise they hold
    # uninitialized data and must not be plotted.
    if stream_arrows and Narrows > 0:
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            # Use a tailless arrow
            plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
        elif scale != 0:
            plt.quiver(x1, x2, dx[:, :, 0]*abs(scale),
                       dx[:, :, 1]*abs(scale), angles='xy')

    # scale may be None, which cannot be compared with 0
    if stream_arrows and scale is not None and scale < 0:
        plt.plot(x1, x2, 'b.')          # add dots at base
1291

1292

1293
# Utility function for generating initial conditions around a box
1294
def box_grid(xlimp, ylimp):
    """Generate list of points on edge of box.

    .. deprecated:: 0.10.0
        Use `phaseplot.boxgrid` instead.

    list = box_grid([xmin xmax xnum], [ymin ymax ynum]) generates a
    list of points that correspond to a uniform grid at the end of the
    box defined by the corners [xmin ymin] and [xmax ymax].

    """
    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    # Sample each axis uniformly, then trace the perimeter of the box
    xpts = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    ypts = np.linspace(ylimp[0], ylimp[1], ylimp[2])
    return boxgrid(xpts, ypts)
1314

1315

1316
# TODO: rename to something more useful (or remove??)
1317
def _find(condition):
    """Return the indices of the nonzero entries of flattened `condition`.

    Private implementation of deprecated `matplotlib.mlab.find`.

    """
    return np.flatnonzero(condition)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc