• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 10370763703

13 Aug 2024 01:32PM UTC coverage: 94.693% (-0.001%) from 94.694%
10370763703

push

github

web-flow
Merge pull request #1038 from murrayrm/doc-comment_fixes-11May2024

Documentation updates and docstring unit tests

9136 of 9648 relevant lines covered (94.69%)

8.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.44
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Author: Richard M. Murray
4
# Date:   23 Mar 2024 (legacy version information below)
5
#
6
# TODO
7
# * Allow multiple timepoints (and change timespec name to T?)
8
# * Update linestyles (color -> linestyle?)
9
# * Check for keyword compatibility with other plot routines
10
# * Set up configuration parameters (nyquist --> phaseplot)
11

12
"""Module for generating 2D phase plane plots.
13

14
The :mod:`control.phaseplot` module contains functions for generating 2D
15
phase plots. The base function for creating phase plane portraits is
16
:func:`~control.phase_plane_plot`, which generates a phase plane portrait
17
for a 2 state I/O system (with no inputs).  In addition, several other
18
functions are available to create customized phase plane plots:
19

20
* boxgrid: Generate a list of points along the edge of a box
21
* circlegrid: Generate list of points around a circle
22
* equilpoints: Plot equilibrium points in the phase plane
23
* meshgrid: Generate a list of points forming a mesh
24
* separatrices: Plot separatrices in the phase plane
25
* streamlines: Plot stream lines in the phase plane
26
* vectorfield: Plot a vector field in the phase plane
27

28
"""
29

30
import math
9✔
31
import warnings
9✔
32

33
import matplotlib as mpl
9✔
34
import matplotlib.pyplot as plt
9✔
35
import numpy as np
9✔
36
from scipy.integrate import odeint
9✔
37

38
from . import config
9✔
39
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
40
    _process_ax_keyword, _update_plot_title
41
from .exception import ControlNotImplemented
9✔
42
from .nlsys import NonlinearIOSystem, find_eqpt, input_output_response
9✔
43

44
# Public interface of this module.  NOTE(review): 'phase_plot' and
# 'box_grid' are not defined in the portion of the file shown here;
# presumably they are the legacy functions defined further below.
__all__ = [
    'phase_plane_plot',
    'phase_plot',
    'box_grid',
]

# Default values for module parameter variables
_phaseplot_defaults = {
    # Number of arrows around curve
    'phaseplot.arrows': 2,
    # Pixel size for arrows
    'phaseplot.arrow_size': 8,
    # Initial radius for separatrices (read by separatrices())
    'phaseplot.separatrices_radius': 0.1,
}
52

53
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
        plot_separatrices=True, ax=None, suppress_warnings=False, title=None,
        **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array, optional
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  Defaults to
        [-1, 1, -1, 1].
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.  For gridtypes other than 'meshgrid' and
        'boxgrid', `gridspec` is only used for the streamlines.
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use `plot_<fcn>={'color': c}`
        to set the color in one element of the phase plot).
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : :class:`ControlPlot` object
        Object containing the data that were plotted:

          * cplt.lines: array of list of :class:`matplotlib.artist.Artist`
            objects:

              - lines[0] = list of Line2D objects (streamlines, separatrices).
              - lines[1] = Quiver object (vector field arrows).
              - lines[2] = list of Line2D objects (equilibrium points).

          * cplt.axes: 2D array of :class:`matplotlib.axes.Axes` for the plot.

          * cplt.figure: :class:`matplotlib.figure.Figure` containing the plot.

        See :class:`ControlPlot` for more detailed information.

    Raises
    ------
    TypeError
        If a keyword argument is not used by any of the plotting functions
        that were called.

    Other parameters
    ----------------
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the ``timedata`` argument.
    plot_streamlines : bool or dict, optional
        If `True` (default) then plot streamlines based on the pointdata
        and gridtype.  If set to a dict, pass on the key-value pairs in
        the dict as keywords to :func:`~control.phaseplot.streamlines`.
    plot_vectorfield : bool or dict, optional
        If `True` then plot the vector field based on the pointdata and
        gridtype (default is `False`).  If set to a dict, pass on the
        key-value pairs in the dict as keywords to
        :func:`~control.phaseplot.vectorfield`.
    plot_equilpoints : bool or dict, optional
        If `True` (default) then plot equilibrium points based in the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to :func:`~control.phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If `True` (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to :func:`~control.phaseplot.separatrices`.
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        If set to `True`, suppress warning messages in generating
        trajectories (streamlines and separatrices).
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    """
    # Process arguments.  'params' is deliberately left in kwargs so that it
    # is forwarded to each of the individual plotting functions below.
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments: start from the global
    # keywords, add the element-specific settings, and finally let a dict
    # passed as plot_<fcn>={...} override everything else.
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Create list for storing outputs
    out = np.array([[], None, None], dtype=object)

    # Plot out the main elements
    if plot_streamlines:
        # suppress_warnings goes through _create_kwargs so that a value in a
        # plot_streamlines dict overrides it instead of raising a
        # duplicate-keyword TypeError
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax, suppress_warnings=suppress_warnings)
        out[0] += streamlines(
            sys, pointdata, timedata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_separatrices:
        # Separatrices also generate trajectories, so honor suppress_warnings
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax,
            suppress_warnings=suppress_warnings)
        out[0] += separatrices(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        out[2] = equilpoints(
            sys, pointdata, _check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used
    if initial_kwargs:
        raise TypeError(f"unrecognized keywords: {initial_kwargs}")

    # Only decorate the figure if we created it ourselves
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
238

239

240
def vectorfield(
        sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
        _check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.  The vector field is evaluated at time 0 with all
    inputs set to zero.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        When `pointdata` describes the plot boundaries, `gridspec` gives
        the size of the grid in the x and y axes on which to generate
        points.  This function always generates the points as a mesh
        ('meshgrid').
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver
        The Quiver object created by `ax.quiver` for the vector field.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (unless the internal
        `_check_kwargs` flag is `False`).

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions.  This
        function does not simulate trajectories and does not use it.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here;
    # presumably the call is made for its effect on `ax`)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Generate phase plane (quiver) data: each row holds [x1, x2, dx1, dx2]
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, np.zeros(sys.ninputs))

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color)

    return out
331

332

333
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        ax=None, _check_kwargs=True, suppress_warnings=False, **kwargs):
    """Plot stream lines in the phase plane.

    This function plots stream lines for a two-dimensional state space
    system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to start the stream lines.
    timedata : int or list of int, optional
        Time to simulate each streamline (default 1).  If a list is given,
        a different time can be used for each initial condition in
        `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' or 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction to draw streamlines: 'forward' to flow forward in time
        from the reference points, 'reverse' to flow backward in time, or
        'both' to flow both forward and backward.  The amount of time to
        simulate in each direction is given by the ``timedata`` argument.
        If not given, 'both' is used when `gridtype` is 'meshgrid' and
        'forward' otherwise.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects
        One Line2D object for each stream line that was drawn.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (unless the internal
        `_check_kwargs` flag is `False`).

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].
    suppress_warnings : bool, optional
        If set to `True`, suppress warning messages in generating trajectories.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Parse the arrows keyword
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Determine the points on which to generate the streamlines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        # NOTE(review): only an explicit gridtype='meshgrid' selects 'both';
        # the default gridtype=None yields 'forward' even though meshgrid is
        # documented as the default grid type.  Confirm this is intended.
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (the limits are used to clip the trajectories)
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Create reverse time system (negated dynamics), if needed
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)
    else:
        revsys = None

    # Generate phase plane (streamline) data
    out = []
    for i, X0 in enumerate(points):
        # Create the trajectory for this point
        timepts = _make_timepts(timedata, i)
        traj = _create_trajectory(
            sys, revsys, timepts, X0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Plot the trajectory (if there is more than a single point)
        if traj.shape[1] > 1:
            with plt.rc_context(rcParams):
                out += ax.plot(traj[0], traj[1], color=color)

                # Add arrows to the lines at specified intervals
                _add_arrows_to_line2D(
                    ax, out[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return out
452

453

454
def equilpoints(
        sys, pointdata, gridspec=None, color='k', ax=None, _check_kwargs=True,
        **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points to use when searching for equilibrium points.
    gridspec : list, optional
        When `pointdata` describes the plot boundaries, `gridspec` gives
        the size of the mesh ('meshgrid') of points in the x and y axes
        used to search for equilibrium points.  Defaults to [5, 5].
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the equilibrium points in the given color (default 'k').
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of lists of Line2D objects
        One entry per equilibrium point, each being the list of Line2D
        objects returned by `ax.plot` for that point's marker.

    Raises
    ------
    TypeError
        If unrecognized keyword arguments are given (unless the internal
        `_check_kwargs` flag is `False`).

    Other parameters
    ----------------
    rcParams : dict
        Override the default parameters used for generating plots.
        Default is set by config.default['ctrlplot.rcParams'].

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (the returned limits are not needed here;
    # presumably the call is made for its effect on `ax`)
    _set_axis_limits(ax, pointdata)

    # Determine the points used to search for equilibrium points
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if _check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points
    out = []
    for xeq in equilpts:
        with plt.rc_context(rcParams):
            out.append(
                ax.plot(xeq[0], xeq[1], marker='o', color=color))
    return out
537

538

539
def separatrices(
9✔
540
        sys, pointdata, timedata=None, gridspec=None, ax=None,
541
        _check_kwargs=True, suppress_warnings=False, **kwargs):
542
    """Plot separatrices in the phase plane.
543

544
    This function plots separatrices for a two-dimensional state space
545
    system.
546

547
    Parameters
548
    ----------
549
    sys : NonlinearIOSystem or callable(t, x, ...)
550
        I/O system or function used to generate phase plane data. If a
551
        function is given, the remaining arguments are drawn from the
552
        `params` keyword.
553
    pointdata : list or 2D array
554
        List of the form [xmin, xmax, ymin, ymax] describing the
555
        boundaries of the phase plot or an array of shape (N, 2)
556
        giving points of at which to plot the vector field.
557
    timedata : int or list of int
558
        Time to simulate each streamline.  If a list is given, a different
559
        time can be used for each initial condition in `pointdata`.
560
    gridtype : str, optional
561
        The type of grid to use for generating initial conditions:
562
        'meshgrid' (default) generates a mesh of initial conditions within
563
        the specified boundaries, 'boxgrid' generates initial conditions
564
        along the edges of the boundary, 'circlegrid' generates a circle of
565
        initial conditions around each point in point data.
566
    gridspec : list, optional
567
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
568
        size of the grid in the x and y axes on which to generate points.
569
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
570
        specifying the radius and number of points around each point in the
571
        `pointdata` array.
572
    params : dict or list, optional
573
        Parameters to pass to system. For an I/O system, `params` should be
574
        a dict of parameters and values. For a callable, `params` should be
575
        dict with key 'args' and value given by a tuple (passed to callable).
576
    color : matplotlib color spec, optional
577
        Plot the separatrics in the given color.  If a single color
578
        specification is given, this is used for both stable and unstable
579
        separatrices.  If a tuple is given, the first element is used as
580
        the color specification for stable separatrices and the second
581
        elmeent for unstable separatrices.
582
    ax : matplotlib.axes.Axes
583
        Use the given axes for the plot, otherwise use the current axes.
584

585
    Returns
586
    -------
587
    out : list of Line2D objects
588

589
    Other parameters
590
    ----------------
591
    rcParams : dict
592
        Override the default parameters used for generating plots.
593
        Default is set by config.default['ctrlplot.rcParams'].
594
    suppress_warnings : bool, optional
595
        If set to `True`, suppress warning messages in generating trajectories.
596

597
    """
598
    # Process keywords
599
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
600

601
    # Get system parameters
602
    params = kwargs.pop('params', None)
9✔
603

604
    # Create system from callable, if needed
605
    sys = _create_system(sys, params)
9✔
606

607
    # Parse the arrows keyword
608
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
609

610
    # Determine the initial states to use in searching for equilibrium points
611
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
612
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
613

614
    # Find the equilibrium points
615
    equilpts = _find_equilpts(sys, points, params=params)
9✔
616
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
617

618
    # Create axis if needed
619
    if ax is None:
9✔
620
        ax = plt.gca()
9✔
621

622
    # Set the axis limits
623
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
624

625
    # Figure out the color to use for stable, unstable subspaces
626
    color = _get_color(kwargs)
9✔
627
    match color:
9✔
628
        case None:
9✔
629
            stable_color = 'r'
9✔
630
            unstable_color = 'b'
9✔
631
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
632
            pass
9✔
633
        case single_color:
9✔
634
            stable_color = unstable_color = color
9✔
635

636
    # Make sure all keyword arguments were processed
637
    if _check_kwargs and kwargs:
9✔
638
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
639

640
    # Create a "reverse time" system to use for simulation
641
    revsys = NonlinearIOSystem(
9✔
642
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
643
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
644
        outputs=sys.noutputs, params=sys.params)
645

646
    # Plot separatrices by flowing backwards in time along eigenspaces
647
    out = []
9✔
648
    for i, xeq in enumerate(equilpts):
9✔
649
        # Plot the equilibrium points
650
        with plt.rc_context(rcParams):
9✔
651
            out.append(
9✔
652
                ax.plot(xeq[0], xeq[1], marker='o', color='k'))
653

654
        # Figure out the linearization and eigenvectors
655
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
656

657
        # See if we have real eigenvalues (=> evecs are meaningful)
658
        if evals[0].imag > 0:
9✔
659
            continue
9✔
660

661
        # Create default list of time points
662
        if timedata is not None:
9✔
663
            timepts = _make_timepts(timedata, i)
9✔
664

665
        # Generate the traces
666
        for j, dir in enumerate(evecs.T):
9✔
667
            # Figure out time vector if not yet computed
668
            if timedata is None:
9✔
669
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
670
                timepts = np.linspace(0, timescale)
9✔
671

672
            # Run the trajectory starting in eigenvector directions
673
            for eps in [-radius, radius]:
9✔
674
                x0 = xeq + dir * eps
9✔
675
                if evals[j].real < 0:
9✔
676
                    traj = _create_trajectory(
9✔
677
                        sys, revsys, timepts, x0, params, 'reverse',
678
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
679
                        suppress_warnings=suppress_warnings)
680
                    color = stable_color
9✔
681
                    linestyle = '--'
9✔
682
                elif evals[j].real > 0:
9✔
683
                    traj = _create_trajectory(
9✔
684
                        sys, revsys, timepts, x0, params, 'forward',
685
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
686
                        suppress_warnings=suppress_warnings)
687
                    color = unstable_color
9✔
688
                    linestyle = '-'
9✔
689

690
                # Plot the trajectory (if there is one)
691
                if traj.shape[1] > 1:
9✔
692
                    with plt.rc_context(rcParams):
9✔
693
                        out.append(ax.plot(
9✔
694
                            traj[0], traj[1], color=color, linestyle=linestyle))
695

696
                    # Add arrows to the lines at specified intervals
697
                    with plt.rc_context(rcParams):
9✔
698
                        _add_arrows_to_line2D(
9✔
699
                            ax, out[-1][0], arrow_pos, arrowstyle=arrow_style,
700
                            dir=1)
701
    return out
9✔
702

703

704
#
705
# User accessible utility functions
706
#
707

708
# Utility function to generate boxgrid (in the form needed here)
709
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    points = boxgrid(xvals, yvals) generates a list of points that lie on
    the boundary of the rectangle spanned by the x and y values.  Starting
    at (xvals[0], yvals[0]), the points run along the lower edge, up the
    right edge, back along the upper edge and down the left edge (i.e.,
    counterclockwise for increasing `xvals` and `yvals`).  When both inputs
    have at least two entries, each corner appears exactly once.

    Parameters
    ----------
    xvals, yvals : 1D array-like
        Points along the x and y axes.  The first and last entries define
        the corners of the box; the entries in between define the
        intermediate points on the lower/upper and left/right edges.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p = 2 * (len(xvals) - 1) + 2 * (len(yvals) - 1) is the
        number of points around the edge.

    """
    # Each edge omits its final point, which is the first point of the
    # next edge, so that no corner is repeated.
    return np.array(
        [(x, yvals[0]) for x in xvals[:-1]] +           # lower edge
        [(xvals[-1], y) for y in yvals[:-1]] +          # right edge
        [(x, yvals[-1]) for x in xvals[:0:-1]] +        # upper edge
        [(xvals[0], y) for y in yvals[:0:-1]]           # left edge
    )
734

735

736
# Utility function to generate meshgrid (in the form needed here)
737
# TODO: add examples of using grid functions directly
738
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    points = meshgrid(xvals, yvals) generates a list of points that
    corresponds to a grid given by the cross product of the x and y values.

    Parameters
    ----------
    xvals, yvals : 1D array-like
        Values along the x and y axes from which the mesh is formed.

    Returns
    -------
    grid: 2D array
        Array of points with shape (n * m, 2) defining the mesh, where
        n = len(xvals) and m = len(yvals).  The x coordinate varies
        fastest: the rows are (xvals[0], yvals[0]), (xvals[1], yvals[0]),
        ..., (xvals[0], yvals[1]), ...  The result is always a float array.

    """
    # np.meshgrid (default 'xy' indexing) returns arrays of shape (m, n);
    # flattening them in row-major order makes x vary fastest.
    xmesh, ymesh = np.meshgrid(xvals, yvals)

    # Allocate a float array (as the original np.zeros did) and fill columns
    grid = np.empty((xmesh.size, 2))
    grid[:, 0] = xmesh.ravel()
    grid[:, 1] = ymesh.ravel()

    return grid
762

763

764
# Utility function to generate circular grid
765
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    points = circlegrid(centers, radius, num) generates a list of points
    that form a circle around a list of centers.

    Parameters
    ----------
    centers : 2D array-like
        Array of points with shape (p, 2) defining centers of the circles.
        A single center may be given as a 1D pair.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.  The points are
        equally spaced in angle, starting at angle 0 (to the right of the
        center) and proceeding counterclockwise.  If zero, an empty grid
        is returned.

    Returns
    -------
    grid: 2D array
        Array of points with shape (p * num, 2) defining the circles.  The
        `num` points around each center are stored contiguously, in the
        same order as `centers`.

    """
    centers = np.atleast_2d(np.array(centers))

    # Offsets of the points from a center, shape (num, 2)
    theta = np.linspace(0, 2 * math.pi, num, endpoint=False)
    offsets = radius * np.column_stack((np.cos(theta), np.sin(theta)))

    # Broadcast (p, 1, 2) + (1, num, 2) -> (p, num, 2), then flatten so that
    # the points for each center are contiguous
    grid = centers[:, np.newaxis, :] + offsets[np.newaxis, :, :]
    return grid.reshape(-1, 2).astype(float, copy=False)
793

794
#
795
# Internal utility functions
796
#
797

798
# Create a system from a callable
799
def _create_system(sys, params):
    """Return a planar NonlinearIOSystem for `sys`.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable
        A NonlinearIOSystem is returned unchanged after checking that it has
        exactly two states.  Any other object is wrapped as a function
        called as ``sys(t, x, *args)``.
    params : dict or None
        Only checked when `sys` is not a NonlinearIOSystem: if non-empty it
        must be a dict whose 'args' entry (not None) holds the extra
        positional arguments for `sys`.  An empty tuple is allowed.

    Returns
    -------
    NonlinearIOSystem
        For a callable, a system with 2 states, no inputs, no outputs and
        name "_callable".  Its update function takes the extra arguments
        from ``params['args']`` of the params it is called with (default
        ``()``).

    Raises
    ------
    ValueError
        If a NonlinearIOSystem does not have exactly 2 states, or if
        `params` is non-empty but is not a dict with an 'args' entry.

    """
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates != 2:
            raise ValueError("system must be planar")
        return sys

    # Make sure that if params is present, it has 'args' key.  Test the
    # entry against None (not its truthiness) so an empty argument tuple is
    # accepted, and check the type first so that a non-dict raises
    # ValueError rather than AttributeError from .get().
    if params and (
            not isinstance(params, dict) or params.get('args') is None):
        raise ValueError("params must be dict with key 'args'")

    def _update(t, x, u, params):
        # 'params' is the argument passed to the update function; it
        # shadows the outer `params`
        return sys(t, x, *params.get('args', ()))

    def _output(t, x, u, params):
        # The wrapped system has no outputs
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
813

814
# Set axis limits for the plot
815
def _set_axis_limits(ax, pointdata):
    """Expand the limits of `ax` so that they cover `pointdata`.

    If the axes already contain lines, their current limits are the
    starting point; otherwise the limits start out empty so that the new
    data alone determines them.  `pointdata` is either a 4-element list
    ``[xmin, xmax, ymin, ymax]`` describing a box, or an ndarray of points
    whose columns 0 and 1 are the x and y coordinates.  Any other value
    leaves the starting limits unchanged.

    Returns
    -------
    xlim, ylim : 2-element sequences
        The limits that were set on the axes.
    maxlim : float
        The larger of the x and y extents.

    """
    def _widen(bounds, values):
        # Smallest interval containing both `bounds` and all of `values`
        return [min(bounds[0], np.min(values)),
                max(bounds[1], np.max(values))]

    if not ax.lines:
        # Nothing on the plot => always use new limits
        xlim = [np.inf, -np.inf]
        ylim = [np.inf, -np.inf]
    else:
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

    if isinstance(pointdata, list) and len(pointdata) == 4:
        # Box given as [xmin, xmax, ymin, ymax]
        xmin, xmax, ymin, ymax = pointdata
        xlim = _widen(xlim, [xmin, xmax])
        ylim = _widen(ylim, [ymin, ymax])
    elif isinstance(pointdata, np.ndarray):
        # Array of (x, y) points
        points = np.atleast_2d(pointdata)
        xlim = _widen(xlim, points[:, 0])
        ylim = _widen(ylim, points[:, 1])

    # Largest extent in either direction
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])

    # Apply the limits to the axes
    for axis_name in ('x', 'y'):
        ax.autoscale(enable=True, axis=axis_name, tight=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return xlim, ylim, maxlim
849

850

851
# Find equilibrium points
852
def _find_equilpts(sys, points, params=None):
    """Find the distinct equilibrium points of `sys` near a set of points.

    Parameters
    ----------
    sys : NonlinearIOSystem
        System passed to `find_eqpt`.
    points : iterable of array-like
        Initial guesses; `find_eqpt` is called once for each.
    params : dict, optional
        Parameters passed to `find_eqpt`.

    Returns
    -------
    equilpts : list of list
        The equilibrium points found, as lists, in order of discovery.
        Guesses for which `find_eqpt` returns None are skipped, and a point
        that is `np.allclose` to one already found is not added again.

    """
    equilpts = []
    for x0 in points:
        # Look for an equilibrium point near this point
        xeq, _ = find_eqpt(sys, x0, 0, params=params)
        if xeq is None:
            continue            # didn't find anything

        # Skip points that we have already found
        if any(np.allclose(np.array(x), xeq) for x in equilpts):
            continue

        # Save a new point
        equilpts.append(xeq.tolist())

    return equilpts
873

874

875
def _make_points(pointdata, gridspec, gridtype):
9✔
876
    # Check to see what type of data we got
877
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
878
        pointdata = np.atleast_2d(pointdata)
9✔
879
        if pointdata.shape[1] == 2:
9✔
880
            # Given a list of points => no action required
881
            return pointdata, None
9✔
882

883
    # Utility function to parse (and check) input arguments
884
    def _parse_args(defsize):
9✔
885
        if gridspec is None:
9✔
886
            return defsize
9✔
887

888
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
889
             len(gridspec) != len(defsize):
890
            raise ValueError("invalid grid specification")
9✔
891

892
        return gridspec
9✔
893

894
    # Generate points based on grid type
895
    match gridtype:
9✔
896
        case 'boxgrid' | None:
9✔
897
            gridspec = _parse_args([6, 4])
9✔
898
            points = boxgrid(
9✔
899
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
900
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
901

902
        case 'meshgrid':
9✔
903
            gridspec = _parse_args([9, 6])
9✔
904
            points = meshgrid(
9✔
905
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
906
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
907

908
        case 'circlegrid':
9✔
909
            gridspec = _parse_args((0.5, 10))
9✔
910
            if isinstance(pointdata, np.ndarray):
9✔
911
                # Create circles around each point
912
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
913
            else:
914
                # Create circle around center of the plot
915
                points = circlegrid(
9✔
916
                    np.array(
917
                        [(pointdata[0] + pointdata[1]) / 2,
918
                         (pointdata[0] + pointdata[1]) / 2]),
919
                    gridspec[0], gridspec[1])
920

921
        case _:
9✔
922
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
923

924
    return points, gridspec
9✔
925

926

927
def _parse_arrow_keywords(kwargs):
    """Extract the arrow-related keywords from `kwargs`.

    The 'arrows', 'arrow_size' and 'arrow_style' entries are obtained with
    ``config._get_param`` (presumably falling back to the 'phaseplot'
    configuration defaults -- see control.config) and all three are
    removed from `kwargs`.  The remaining keywords can then be passed to
    the plotting functions or checked for unrecognized keywords.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments; modified in place.

    Returns
    -------
    arrow_pos : list or 1D ndarray
        Arrow locations along each line: empty if arrows are disabled,
        ``(k + 0.5) / N`` for k = 0, ..., N-1 if `arrows` is an integer N,
        or the sorted values if `arrows` is a list, tuple or array.
    arrow_style : matplotlib.patches.ArrowStyle or user-supplied style
        If no style was given, a 'simple' arrow sized from `arrow_size`.

    Raises
    ------
    ValueError
        If `arrows` is not false-like, an integer, or a sequence/array.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # Parse the arrows keyword.  Explicit positions are checked first since
    # the truth value of a multi-element ndarray is ambiguous (`not arrows`
    # would raise).
    if isinstance(arrows, (list, tuple, np.ndarray)):
        arrow_pos = np.sort(np.atleast_1d(arrows)) if np.size(arrows) else []
    elif not arrows:
        arrow_pos = []
    elif isinstance(arrows, (int, np.integer)):
        N = int(arrows)
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
955

956

957
# TODO: move to ctrlplot?
958
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory through `X0` and clip it to the plot window.

    Parameters
    ----------
    sys, revsys : NonlinearIOSystem
        Systems simulated for the forward and reverse directions (revsys
        is presumably the time-reversed version of sys -- confirm at the
        call sites).
    timepts : array-like
        Time points passed to `input_output_response`.
    X0 : array-like
        Initial state.
    params : dict or None
        Parameters passed to `input_output_response`.
    dir : {'forward', 'reverse', 'both'}
        Which responses to compute.  For 'reverse' the reverse response is
        flipped so that it ends at X0; for 'both' the flipped reverse
        response (without its copy of X0) is followed by the forward one.
    suppress_warnings : bool, optional
        If False (default), warn when a simulation does not succeed.
    gridtype, gridspec : optional
        Not used; accepted for call compatibility.
    xlim, ylim : 2-element sequences, optional
        Window limits.  Points outside the window are removed, except that
        points adjacent to an in-range point are kept so that lines reach
        the boundary.  If None, no clipping is done in that direction.

    Returns
    -------
    traj : 2D ndarray
        States of the kept points; rows 0 and 1 are the x and y coordinates.

    Raises
    ------
    ValueError
        If `dir` is not one of the values listed above.

    """
    if dir not in ('forward', 'reverse', 'both'):
        raise ValueError(f"unknown direction '{dir}'")

    # Compute the forward trajectory
    if dir in ('forward', 'both'):
        fwdresp = input_output_response(
            sys, timepts, X0=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir in ('reverse', 'both'):
        revresp = input_output_response(
            revsys, timepts, X0=X0, params=params, ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {revresp.message}")

    # Create the trace to plot
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Both responses start at X0 (column 0); drop only the reverse copy
        # so that no reverse-time sample is lost
        traj = np.hstack([revresp.states[:, :0:-1], fwdresp.states])

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.ones(traj.shape[1], dtype=bool)
    if xlim is not None:
        inrange &= (traj[0] >= xlim[0]) & (traj[0] <= xlim[1])
    if ylim is not None:
        inrange &= (traj[1] >= ylim[0]) & (traj[1] <= ylim[1])
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
991

992

993
def _make_timepts(timepts, i):
    """Return the time points to use for the i-th trajectory.

    Parameters
    ----------
    timepts : None, scalar or array-like
        None gives ``np.linspace(0, 1)``; a scalar T (Python or numpy
        number) gives ``np.linspace(0, T)``; a 1D array (or list) is used
        for every trajectory; a 2D array supplies one row of time points
        per trajectory.
    i : int
        Index of the trajectory; only used for 2D `timepts`.

    Returns
    -------
    ndarray
        The time points.  A 1D ndarray input is returned unchanged.

    """
    if timepts is None:
        return np.linspace(0, 1)

    # Accept lists and numpy scalars as well as Python numbers and arrays
    # (np.asarray returns an existing ndarray unchanged)
    timepts = np.asarray(timepts)
    if timepts.ndim == 0:
        return np.linspace(0, timepts)
    if timepts.ndim == 2:
        return timepts[i]
    return timepts
1001

1002

1003
#
1004
# Legacy phase plot function
1005
#
1006
# Author: Richard Murray
1007
# Date: 24 July 2011, converted from MATLAB version (2002); based on
1008
# a version by Kristi Morgansen
1009
#
1010
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):
    """(legacy) Phase plot for 2D dynamical systems.

    .. deprecated:: 0.10.1
        This function is deprecated; use :func:`phase_plane_plot` instead.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by :func:`~control.phase_plane_plot` and the
    helper functions in the :mod:`control.phaseplot` module.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - stream lines
          with N arrows on each line
      phase_plot(func, X0=[...], logtime=(N, lambda), ...) - stream lines
          with N arrows spaced in time by 1/lambda

    The arrow options `lingrid`, `logtime` and `timepts` are checked in
    that order and only the first one given is used.  If any of them is
    given, the vector field defined by `X` and `Y` is not drawn.

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Computes the time derivative of the state (compatible with
        :func:`scipy.integrate.odeint`).  It should accept a state x of
        dimension 2 and return a derivative dx/dt of dimension 2.  It is
        called as ``odefun(x, t, *params)``, or as ``odefun(t, x, *params)``
        if `tfirst` is True.
    X, Y: 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.
    scale: float, optional
        Scale size of arrows; default = 1.  If 0, no arrows are drawn.  If
        None, unscaled arrows are drawn, and stream line arrows outside the
        current axis limits are drawn with zero length.  If negative, dots
        are also drawn at the base of the stream line arrows.
    X0: ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T: array-like or number, optional
        Length of time to run simulations that generate streamlines.  If a
        single number, 100 time points between 0 and T are used.
        Otherwise T is used directly as the array of time points for every
        initial condition.  Default value = 50.
    lingrid : int, optional
        Draw `lingrid` arrows along each stream line, at approximately
        equally spaced indices of the simulation time vector.
    lintime : optional
        Accepted for backward compatibility; currently ignored.
    logtime : tuple (int, float), optional
        If a tuple (N, lambda) is given, draw N arrows along each stream
        line, at the last simulation times before multiples of 1/lambda.
    timepts : array-like, optional
        Draw arrows at the given list times [t1, t2, ...]
    parms : tuple, optional
        Deprecated alias for `params`.
    params: tuple, optional
        List of parameters to pass to vector field: `odefun(x, t, *params)`
    tfirst : bool, optional
        If True, call `odefun` with signature `odefun(t, x, ...)`.
    verbose : bool, optional
        If True (default), print a message when `lingrid` or `logtime`
        arrows are used.

    See also
    --------
    boxgrid : construct box-shaped grid of initial conditions

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot() is deprecated; use phase_plane_plot() instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            "keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            raise ControlArgument("duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if verbose:
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if verbose:
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # Arrows along the stream lines are only computed for these options
    streamArrows = autoFlag or logtimeFlag or timeptsFlag

    if not streamArrows and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i,j], x2[i,j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i,j], x2[i,j]], 0, *params))

        # Plot the quiver plot
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif scale != 0:
            #! TODO: optimize parameters for arrows
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:,:,0]*np.abs(scale),
                       dx[:,:,1]*np.abs(scale), angles='xy')
            # set(xy, 'LineWidth', PP_arrow_linewidth, 'Color', 'b')

        #! TODO: Tweak the shape of the plot
        # a=gca; set(a,'DataAspectRatio',[1,1,1])
        # set(a,'XLim',X(1:2)); set(a,'YLim',Y(1:2))
        plt.xlabel('x1'); plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    nr, _ = np.shape(X0)

    # Arrow positions and directions along the stream lines.  Entries that
    # are never filled in (e.g. skipped tailless arrows) stay NaN so that
    # nothing is drawn for them.
    x1 = np.full((nr, Narrows), np.nan)
    x2 = np.full((nr, Narrows), np.nan)
    dx = np.full((nr, Narrows, 2), np.nan)

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed (as an array, so that it can be
    # compared elementwise when locating arrow times)
    if isinstance(T, (int, float)):
        TSPAN = np.linspace(0, T, 100)
    else:
        TSPAN = np.asarray(T)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin = alim[0]; xmax = alim[1]
        ymin = alim[2]; ymax = alim[3]
    else:
        # Use the maximum extent of all trajectories
        xmin = np.min(X0[:,0]); xmax = np.max(X0[:,0])
        ymin = np.min(X0[:,1]); ymax = np.max(X0[:,1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = TSPAN

        plt.plot(state[:,0], state[:,1])
        #! TODO: add back in colors for stream lines
        # PP_stream_color(np.mod(i-1, len(PP_stream_color))+1))
        # set(h[i], 'LineWidth', PP_stream_linewidth)

        # Plot arrows if quiver parameters were 'auto'
        if streamArrows:
            # Compute the locations of the arrows
            for j in range(Narrows):

                # Figure out starting index; headless arrows start at 0
                k = -1 if scale is None else 0

                # Figure out what time index to use for the next point
                if autoFlag:
                    # Use a linear scaling based on ODE time vector; the
                    # result must be an integer to index `state`
                    tind = int(np.floor((len(time)/Narrows) * (j-k)) + k)
                elif logtimeFlag:
                    # Use an exponential time vector
                    # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                    tarr = _find(time < (j-k) / timefactor)
                    tind = tarr[-1] if len(tarr) else 0
                elif timeptsFlag:
                    # Use specified time points
                    # MATLAB: tind = find(time < Y[j], 1, 'last')
                    tarr = _find(time < timepts[j])
                    tind = tarr[-1] if len(tarr) else 0

                # For tailless arrows, skip the first point
                if tind == 0 and scale is None:
                    continue

                # Figure out the arrow at this point on the curve
                x1[i,j] = state[tind, 0]
                x2[i,j] = state[tind, 1]

                # Skip arrows outside of initial condition box
                if (scale is not None or
                     (x1[i,j] <= xmax and x1[i,j] >= xmin and
                      x2[i,j] <= ymax and x2[i,j] >= ymin)):
                    if tfirst:
                        v = odefun(0, [x1[i,j], x2[i,j]], *params)
                    else:
                        v = odefun([x1[i,j], x2[i,j]], 0, *params)
                    dx[i, j, 0] = v[0]; dx[i, j, 1] = v[1]
                else:
                    dx[i, j, 0] = 0; dx[i, j, 1] = 0

    # Set the plot shape before plotting arrows to avoid warping
    # a=gca
    # if (scale != None):
    #     set(a,'DataAspectRatio', [1,1,1])
    # if (xmin != xmax and ymin != ymax):
    #     plt.axis([xmin, xmax, ymin, ymax])
    # set(a, 'Box', 'on')

    # Plot arrows on the streamlines (only if they were computed above;
    # otherwise x1, x2, dx hold no arrow data)
    if streamArrows and Narrows > 0:
        if scale is None:
            # Use a tailless arrow
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:,:,0], dx[:,:,1], angles='xy')
        elif scale != 0:
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:,:,0]*abs(scale), dx[:,:,1]*abs(scale),
                       angles='xy')
            # set(xy, 'LineWidth', PP_arrow_linewidth)
            # set(xy, 'AutoScale', 'off')
            # set(xy, 'AutoScaleFactor', 0)

        # Negative scale => add dots at base (scale may be None here)
        if scale is not None and scale < 0:
            plt.plot(x1, x2, 'b.')
            # set(bp, 'MarkerSize', PP_arrow_markersize)
1271

1272

1273
# Utility function for generating initial conditions around a box
1274
def box_grid(xlimp, ylimp):
    """Generate list of points on edge of box.

    .. deprecated:: 0.10.0
        Use :func:`phaseplot.boxgrid` instead.

    list = box_grid([xmin, xmax, xnum], [ymin, ymax, ynum]) generates a
    list of points that correspond to a uniform grid along the edge of the
    box defined by the corners [xmin, ymin] and [xmax, ymax].

    Parameters
    ----------
    xlimp, ylimp : 3-element sequences
        [min, max, num] for each axis; passed to `np.linspace` to generate
        the points along that axis.

    Returns
    -------
    grid : 2D array
        Array with shape (p, 2) of the points along the edge of the box,
        as returned by `boxgrid`.

    """

    # Generate a deprecation warning
    warnings.warn(
        "box_grid() is deprecated; use phaseplot.boxgrid() instead",
        FutureWarning)

    return boxgrid(
        np.linspace(xlimp[0], xlimp[1], xlimp[2]),
        np.linspace(ylimp[0], ylimp[1], ylimp[2]))
1294

1295

1296
# TODO: rename to something more useful (or remove??)
1297
def _find(condition):
    """Return the flat indices at which `condition` is true.

    Private replacement for the deprecated ``matplotlib.mlab.find``: the
    condition is flattened and the indices of its nonzero entries are
    returned as a 1D integer array.
    """
    return np.flatnonzero(condition)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc