• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

python-control / python-control / 10312443515

09 Aug 2024 02:07AM UTC coverage: 94.694% (+0.04%) from 94.65%
10312443515

push

github

web-flow
Merge pull request #1034 from murrayrm/ctrlplot_updates-27Jun2024

Control plot refactoring for consistent functionality

9137 of 9649 relevant lines covered (94.69%)

8.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.44
control/phaseplot.py
1
# phaseplot.py - generate 2D phase portraits
2
#
3
# Author: Richard M. Murray
4
# Date:   23 Mar 2024 (legacy version information below)
5
#
6
# TODO
7
# * Allow multiple timepoints (and change timespec name to T?)
8
# * Update linestyles (color -> linestyle?)
9
# * Check for keyword compatibility with other plot routines
10
# * Set up configuration parameters (nyquist --> phaseplot)
11

12
"""Module for generating 2D phase plane plots.
13

14
The :mod:`control.phaseplot` module contains functions for generating 2D
15
phase plots. The base function for creating phase plane portraits is
16
:func:`~control.phase_plane_plot`, which generates a phase plane portrait
17
for a 2 state I/O system (with no inputs).  In addition, several other
18
functions are available to create customized phase plane plots:
19

20
* boxgrid: Generate a list of points along the edge of a box
21
* circlegrid: Generate list of points around a circle
22
* equilpoints: Plot equilibrium points in the phase plane
23
* meshgrid: Generate a list of points forming a mesh
24
* separatrices: Plot separatrices in the phase plane
25
* streamlines: Plot stream lines in the phase plane
26
* vectorfield: Plot a vector field in the phase plane
27

28
"""
29

30
import math
9✔
31
import warnings
9✔
32

33
import matplotlib as mpl
9✔
34
import matplotlib.pyplot as plt
9✔
35
import numpy as np
9✔
36
from scipy.integrate import odeint
9✔
37

38
from . import config
9✔
39
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _get_color, \
9✔
40
    _process_ax_keyword, _update_plot_title
41
from .exception import ControlNotImplemented
9✔
42
from .nlsys import NonlinearIOSystem, find_eqpt, input_output_response
9✔
43

44
__all__ = ['phase_plane_plot', 'phase_plot', 'box_grid']
9✔
45

46
# Default values for module parameter variables
47
_phaseplot_defaults = {
    # Keys follow the '<module>.<parameter>' form.  Within this file,
    # separatrices() looks up 'separatrices_radius' through
    # config._get_param('phaseplot', 'separatrices_radius').
    'phaseplot.arrows': 2,                  # number of arrows around curve
    'phaseplot.arrow_size': 8,              # pixel size for arrows
    'phaseplot.separatrices_radius': 0.1    # initial radius for separatrices
}
52

53
def phase_plane_plot(
        sys, pointdata=None, timedata=None, gridtype=None, gridspec=None,
        plot_streamlines=True, plot_vectorfield=False, plot_equilpoints=True,
        plot_separatrices=True, ax=None, suppress_warnings=False, title=None,
        **kwargs
):
    """Plot phase plane diagram.

    This function plots phase plane data, including vector fields, stream
    lines, equilibrium points, and contour curves.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.  If not given,
        [-1, 1, -1, 1] is used.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.  For any gridtype other than 'meshgrid' or
        'boxgrid', `gridspec` is only passed to the streamlines and is
        reset to None for the other plot elements.
    params : dict, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot all elements in the given color (use `plot_<fcn>={'color': c}`
        to set the color in one element of the phase plot).
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes to draw the figure on.  If not specified and
        the current figure has a single axes, that axes is used.
        Otherwise, a new figure is created.

    Returns
    -------
    cplt : :class:`ControlPlot` object
        Object containing the data that were plotted:

          * cplt.lines: array of list of :class:`matplotlib.artist.Artist`
            objects:

              - lines[0] = artists returned by `streamlines` followed by
                those returned by `separatrices`.
              - lines[1] = Quiver object (vector field arrows), or None
                if the vector field was not plotted.
              - lines[2] = artists returned by `equilpoints`, or None if
                equilibrium points were not plotted.

          * cplt.axes: 2D array of :class:`matplotlib.axes.Axes` for the plot.

          * cplt.figure: :class:`matplotlib.figure.Figure` containing the plot.

        See :class:`ControlPlot` for more detailed information.

    Raises
    ------
    TypeError
        If a keyword argument is not consumed by any of the enabled plot
        elements.

    Other parameters
    ----------------
    plot_streamlines : bool or dict, optional
        If `True` (default) then plot streamlines based on the pointdata
        and gridtype.  If set to a dict, pass on the key-value pairs in
        the dict as keywords to :func:`~control.phaseplot.streamlines`.
    plot_vectorfield : bool or dict, optional
        If `True` then plot the vector field based on the pointdata
        and gridtype (default is `False`).  If set to a dict, pass on the
        key-value pairs in the dict as keywords to
        :func:`~control.phaseplot.vectorfield`.
    plot_equilpoints : bool or dict, optional
        If `True` (default) then plot equilibrium points based in the phase
        plot boundary. If set to a dict, pass on the key-value pairs in the
        dict as keywords to :func:`~control.phaseplot.equilpoints`.
    plot_separatrices : bool or dict, optional
        If `True` (default) then plot separatrices starting from each
        equilibrium point.  If set to a dict, pass on the key-value pairs
        in the dict as keywords to :func:`~control.phaseplot.separatrices`.
    suppress_warnings : bool, optional
        If set to `True`, suppress warning messages in generating trajectories.
    title : str, optional
        Set the title of the plot.  Defaults to plot type and system name(s).

    """
    # Process arguments ('params' is left in kwargs so that it is also
    # forwarded to each of the plotting functions below)
    params = kwargs.get('params', None)
    sys = _create_system(sys, params)
    pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Create axis if needed
    user_ax = ax
    fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)

    # Create copy of kwargs for later checking to find unused arguments
    initial_kwargs = dict(kwargs)

    # Utility function to create keyword arguments: start from the global
    # keywords, add call-specific ones, then let a per-element dict
    # (plot_<fcn>={...}) override both.
    def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
        new_kwargs = dict(global_kwargs)
        new_kwargs.update(other_kwargs)
        if isinstance(local_kwargs, dict):
            new_kwargs.update(local_kwargs)
        return new_kwargs

    # Create list for storing outputs
    out = np.array([[], None, None], dtype=object)

    # Plot out the main elements
    if plot_streamlines:
        kwargs_local = _create_kwargs(
            kwargs, plot_streamlines, gridspec=gridspec, gridtype=gridtype,
            ax=ax)
        out[0] += streamlines(
            sys, pointdata, timedata, check_kwargs=False,
            suppress_warnings=suppress_warnings, **kwargs_local)

        # Get rid of keyword arguments handled by streamlines
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'color',
                   'dir', 'params']:
            initial_kwargs.pop(kw, None)

    # Reset the gridspec for the remaining commands, if needed
    if gridtype not in [None, 'boxgrid', 'meshgrid']:
        gridspec = None

    if plot_separatrices:
        kwargs_local = _create_kwargs(
            kwargs, plot_separatrices, gridspec=gridspec, ax=ax)
        out[0] += separatrices(
            sys, pointdata, check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by separatrices
        for kw in ['arrows', 'arrow_size', 'arrow_style', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_vectorfield:
        kwargs_local = _create_kwargs(
            kwargs, plot_vectorfield, gridspec=gridspec, ax=ax)
        out[1] = vectorfield(
            sys, pointdata, check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by vectorfield
        for kw in ['color', 'params']:
            initial_kwargs.pop(kw, None)

    if plot_equilpoints:
        kwargs_local = _create_kwargs(
            kwargs, plot_equilpoints, gridspec=gridspec, ax=ax)
        out[2] = equilpoints(
            sys, pointdata, check_kwargs=False, **kwargs_local)

        # Get rid of keyword arguments handled by equilpoints
        for kw in ['params']:
            initial_kwargs.pop(kw, None)

    # Make sure all keyword arguments were used (single message string, so
    # that str(exception) is readable rather than a tuple repr)
    if initial_kwargs:
        raise TypeError(f"unrecognized keywords: {initial_kwargs}")

    # Only decorate the axes when we created them ourselves
    if user_ax is None:
        if title is None:
            title = f"Phase portrait for {sys.name}"
        _update_plot_title(title, use_existing=False, rcParams=rcParams)
        ax.set_xlabel(sys.state_labels[0])
        ax.set_ylabel(sys.state_labels[1])
        plt.tight_layout()

    return ControlPlot(out, ax, fig)
9✔
230

231

232
def vectorfield(
        sys, pointdata, gridspec=None, ax=None, suppress_warnings=False,
        check_kwargs=True, **kwargs):
    """Plot a vector field in the phase plane.

    This function plots a vector field for a two-dimensional state
    space system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes on which to generate points.
        The points are always generated using the 'meshgrid' grid type.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : matplotlib color spec, optional
        Plot the vector field in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : Quiver

    Raises
    ------
    TypeError
        If `check_kwargs` is True and unprocessed keywords remain.

    Other parameters
    ----------------
    check_kwargs : bool, optional
        If `True` (default), raise an error when keyword arguments are
        left over after processing.
    suppress_warnings : bool, optional
        Accepted for consistency with the other phase plot functions; it
        is not used when generating the vector field.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Determine the points on which to generate the vector field
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the plotting limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Generate phase plane (quiver) data: columns are x1, x2, dx1, dx2
    vfdata = np.zeros((points.shape[0], 4))
    sys._update_params(params)
    for i, x in enumerate(points):
        vfdata[i, :2] = x
        vfdata[i, 2:] = sys._rhs(0, x, 0)

    with plt.rc_context(rcParams):
        out = ax.quiver(
            vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
            angles='xy', color=color)

    return out
9✔
320

321

322
def streamlines(
        sys, pointdata, timedata=1, gridspec=None, gridtype=None, dir=None,
        ax=None, check_kwargs=True, suppress_warnings=False, **kwargs):
    """Plot stream lines in the phase plane.

    This function plots stream lines for a two-dimensional state space
    system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data.  If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    timedata : int or list of int
        Time to simulate each streamline.  If a list is given, a different
        time can be used for each initial condition in `pointdata`.
    gridtype : str, optional
        The type of grid to use for generating initial conditions:
        'meshgrid' (default) generates a mesh of initial conditions within
        the specified boundaries, 'boxgrid' generates initial conditions
        along the edges of the boundary, 'circlegrid' generates a circle of
        initial conditions around each point in point data.
    gridspec : list, optional
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
        size of the grid in the x and y axes on which to generate points.
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
        specifying the radius and number of points around each point in the
        `pointdata` array.
    dir : str, optional
        Direction in which to simulate each streamline.  If not given,
        'both' is used when `gridtype` is 'meshgrid' and 'forward'
        otherwise.  A reverse-time system is constructed for any value
        other than 'forward'.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the streamlines in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list of Line2D objects

    Raises
    ------
    TypeError
        If `check_kwargs` is True and unprocessed keywords remain.

    Other parameters
    ----------------
    check_kwargs : bool, optional
        If `True` (default), raise an error when keyword arguments are
        left over after processing.
    suppress_warnings : bool, optional
        If set to `True`, suppress warning messages in generating trajectories.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Parse the arrows keyword
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)

    # Determine the points on which to generate the streamlines
    points, gridspec = _make_points(pointdata, gridspec, gridtype=gridtype)
    if dir is None:
        dir = 'both' if gridtype == 'meshgrid' else 'forward'

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (the overall maximum limit is not needed here)
    xlim, ylim, _ = _set_axis_limits(ax, pointdata)

    # Figure out the color to use
    color = _get_color(kwargs, ax=ax)

    # Make sure all keyword arguments were processed
    if check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Create reverse time system, if needed
    if dir != 'forward':
        revsys = NonlinearIOSystem(
            lambda t, x, u, params: -np.asarray(sys.updfcn(t, x, u, params)),
            sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
            outputs=sys.noutputs, params=sys.params)
    else:
        revsys = None

    # Generate phase plane (streamline) data
    out = []
    for i, X0 in enumerate(points):
        # Create the trajectory for this point
        timepts = _make_timepts(timedata, i)
        traj = _create_trajectory(
            sys, revsys, timepts, X0, params, dir,
            gridtype=gridtype, gridspec=gridspec, xlim=xlim, ylim=ylim,
            suppress_warnings=suppress_warnings)

        # Plot the trajectory (if there is one)
        if traj.shape[1] > 1:
            with plt.rc_context(rcParams):
                out += ax.plot(traj[0], traj[1], color=color)

                # Add arrows to the lines at specified intervals
                _add_arrows_to_line2D(
                    ax, out[-1], arrow_pos, arrowstyle=arrow_style, dir=1)
    return out
9✔
433

434

435
def equilpoints(
        sys, pointdata, gridspec=None, color='k', ax=None, check_kwargs=True,
        **kwargs):
    """Plot equilibrium points in the phase plane.

    This function plots the equilibrium points for a planar dynamical system.

    Parameters
    ----------
    sys : NonlinearIOSystem or callable(t, x, ...)
        I/O system or function used to generate phase plane data. If a
        function is given, the remaining arguments are drawn from the
        `params` keyword.
    pointdata : list or 2D array
        List of the form [xmin, xmax, ymin, ymax] describing the
        boundaries of the phase plot or an array of shape (N, 2)
        giving points at which to plot the vector field.
    gridspec : list, optional
        Size of the grid in the x and y axes used to generate the initial
        guesses for the equilibrium point search.  Defaults to [5, 5].
        The points are always generated using the 'meshgrid' grid type.
    params : dict or list, optional
        Parameters to pass to system. For an I/O system, `params` should be
        a dict of parameters and values. For a callable, `params` should be
        dict with key 'args' and value given by a tuple (passed to callable).
    color : str
        Plot the equilibrium points in the given color.
    ax : matplotlib.axes.Axes
        Use the given axes for the plot, otherwise use the current axes.

    Returns
    -------
    out : list
        One entry per equilibrium point found; each entry is the list of
        Line2D objects returned by the `ax.plot` call for that point.

    Raises
    ------
    TypeError
        If `check_kwargs` is True and unprocessed keywords remain.

    Other parameters
    ----------------
    check_kwargs : bool, optional
        If `True` (default), raise an error when keyword arguments are
        left over after processing.

    """
    # Process keywords
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)

    # Get system parameters
    params = kwargs.pop('params', None)

    # Create system from callable, if needed
    sys = _create_system(sys, params)

    # Create axis if needed
    if ax is None:
        ax = plt.gca()

    # Set the axis limits (the returned limits are not needed here)
    _set_axis_limits(ax, pointdata)

    # Determine the initial points to use in the equilibrium point search
    gridspec = [5, 5] if gridspec is None else gridspec
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')

    # Make sure all keyword arguments were processed
    if check_kwargs and kwargs:
        raise TypeError(f"unrecognized keywords: {kwargs}")

    # Search for equilibrium points
    equilpts = _find_equilpts(sys, points, params=params)

    # Plot the equilibrium points
    out = []
    for xeq in equilpts:
        with plt.rc_context(rcParams):
            out.append(
                ax.plot(xeq[0], xeq[1], marker='o', color=color))
    return out
9✔
512

513

514
def separatrices(
9✔
515
        sys, pointdata, timedata=None, gridspec=None, ax=None,
516
        check_kwargs=True, suppress_warnings=False, **kwargs):
517
    """Plot separatrices in the phase plane.
518

519
    This function plots separatrices for a two-dimensional state space
520
    system.
521

522
    Parameters
523
    ----------
524
    sys : NonlinearIOSystem or callable(t, x, ...)
525
        I/O system or function used to generate phase plane data. If a
526
        function is given, the remaining arguments are drawn from the
527
        `params` keyword.
528
    pointdata : list or 2D array
529
        List of the form [xmin, xmax, ymin, ymax] describing the
530
        boundaries of the phase plot or an array of shape (N, 2)
531
        giving points of at which to plot the vector field.
532
    timedata : int or list of int
533
        Time to simulate each streamline.  If a list is given, a different
534
        time can be used for each initial condition in `pointdata`.
535
    gridtype : str, optional
536
        The type of grid to use for generating initial conditions:
537
        'meshgrid' (default) generates a mesh of initial conditions within
538
        the specified boundaries, 'boxgrid' generates initial conditions
539
        along the edges of the boundary, 'circlegrid' generates a circle of
540
        initial conditions around each point in point data.
541
    gridspec : list, optional
542
        If the gridtype is 'meshgrid' and 'boxgrid', `gridspec` gives the
543
        size of the grid in the x and y axes on which to generate points.
544
        If gridtype is 'circlegrid', then `gridspec` is a 2-tuple
545
        specifying the radius and number of points around each point in the
546
        `pointdata` array.
547
    params : dict or list, optional
548
        Parameters to pass to system. For an I/O system, `params` should be
549
        a dict of parameters and values. For a callable, `params` should be
550
        dict with key 'args' and value given by a tuple (passed to callable).
551
    color : matplotlib color spec, optional
552
        Plot the separatrics in the given color.  If a single color
553
        specification is given, this is used for both stable and unstable
554
        separatrices.  If a tuple is given, the first element is used as
555
        the color specification for stable separatrices and the second
556
        elmeent for unstable separatrices.
557
    ax : matplotlib.axes.Axes
558
        Use the given axes for the plot, otherwise use the current axes.
559

560
    Returns
561
    -------
562
    out : list of Line2D objects
563

564
    Other parameters
565
    ----------------
566
    suppress_warnings : bool, optional
567
        If set to `True`, suppress warning messages in generating trajectories.
568

569
    """
570
    # Process keywords
571
    rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
9✔
572

573
    # Get system parameters
574
    params = kwargs.pop('params', None)
9✔
575

576
    # Create system from callable, if needed
577
    sys = _create_system(sys, params)
9✔
578

579
    # Parse the arrows keyword
580
    arrow_pos, arrow_style = _parse_arrow_keywords(kwargs)
9✔
581

582
    # Determine the initial states to use in searching for equilibrium points
583
    gridspec = [5, 5] if gridspec is None else gridspec
9✔
584
    points, _ = _make_points(pointdata, gridspec, 'meshgrid')
9✔
585

586
    # Find the equilibrium points
587
    equilpts = _find_equilpts(sys, points, params=params)
9✔
588
    radius = config._get_param('phaseplot', 'separatrices_radius')
9✔
589

590
    # Create axis if needed
591
    if ax is None:
9✔
592
        ax = plt.gca()
9✔
593

594
    # Set the axis limits
595
    xlim, ylim, maxlim = _set_axis_limits(ax, pointdata)
9✔
596

597
    # Figure out the color to use for stable, unstable subspaces
598
    color = _get_color(kwargs)
9✔
599
    match color:
9✔
600
        case None:
9✔
601
            stable_color = 'r'
9✔
602
            unstable_color = 'b'
9✔
603
        case (stable_color, unstable_color) | [stable_color, unstable_color]:
9✔
604
            pass
9✔
605
        case single_color:
9✔
606
            stable_color = unstable_color = color
9✔
607

608
    # Make sure all keyword arguments were processed
609
    if check_kwargs and kwargs:
9✔
610
        raise TypeError("unrecognized keywords: ", str(kwargs))
9✔
611

612
    # Create a "reverse time" system to use for simulation
613
    revsys = NonlinearIOSystem(
9✔
614
        lambda t, x, u, params: -np.array(sys.updfcn(t, x, u, params)),
615
        sys.outfcn, states=sys.nstates, inputs=sys.ninputs,
616
        outputs=sys.noutputs, params=sys.params)
617

618
    # Plot separatrices by flowing backwards in time along eigenspaces
619
    out = []
9✔
620
    for i, xeq in enumerate(equilpts):
9✔
621
        # Plot the equilibrium points
622
        with plt.rc_context(rcParams):
9✔
623
            out.append(
9✔
624
                ax.plot(xeq[0], xeq[1], marker='o', color='k'))
625

626
        # Figure out the linearization and eigenvectors
627
        evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
9✔
628

629
        # See if we have real eigenvalues (=> evecs are meaningful)
630
        if evals[0].imag > 0:
9✔
631
            continue
9✔
632

633
        # Create default list of time points
634
        if timedata is not None:
9✔
635
            timepts = _make_timepts(timedata, i)
9✔
636

637
        # Generate the traces
638
        for j, dir in enumerate(evecs.T):
9✔
639
            # Figure out time vector if not yet computed
640
            if timedata is None:
9✔
641
                timescale = math.log(maxlim / radius) / abs(evals[j].real)
9✔
642
                timepts = np.linspace(0, timescale)
9✔
643

644
            # Run the trajectory starting in eigenvector directions
645
            for eps in [-radius, radius]:
9✔
646
                x0 = xeq + dir * eps
9✔
647
                if evals[j].real < 0:
9✔
648
                    traj = _create_trajectory(
9✔
649
                        sys, revsys, timepts, x0, params, 'reverse',
650
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
651
                        suppress_warnings=suppress_warnings)
652
                    color = stable_color
9✔
653
                    linestyle = '--'
9✔
654
                elif evals[j].real > 0:
9✔
655
                    traj = _create_trajectory(
9✔
656
                        sys, revsys, timepts, x0, params, 'forward',
657
                        gridtype='boxgrid', xlim=xlim, ylim=ylim,
658
                        suppress_warnings=suppress_warnings)
659
                    color = unstable_color
9✔
660
                    linestyle = '-'
9✔
661

662
                # Plot the trajectory (if there is one)
663
                if traj.shape[1] > 1:
9✔
664
                    with plt.rc_context(rcParams):
9✔
665
                        out.append(ax.plot(
9✔
666
                            traj[0], traj[1], color=color, linestyle=linestyle))
667

668
                    # Add arrows to the lines at specified intervals
669
                    with plt.rc_context(rcParams):
9✔
670
                        _add_arrows_to_line2D(
9✔
671
                            ax, out[-1][0], arrow_pos, arrowstyle=arrow_style,
672
                            dir=1)
673
    return out
9✔
674

675

676
#
677
# User accessible utility functions
678
#
679

680
# Utility function to generate boxgrid (in the form needed here)
681
def boxgrid(xvals, yvals):
    """Generate list of points along the edge of box.

    points = boxgrid(xvals, yvals) generates a list of points that lie on
    the perimeter of the box spanned by the x and y values.  The points
    are ordered along the lower edge, right edge, upper edge (traversed in
    reverse x order), and left edge (traversed in reverse y order), with
    each corner appearing exactly once.

    Parameters
    ----------
    xvals, yvals: 1D array-like
        Array of points defining the points on the lower and left edges of
        the box.

    Returns
    -------
    grid: 2D array
        Array with shape (p, 2) defining the points along the edges of the
        box, where p is the number of points around the edge.  If both
        `xvals` and `yvals` contain a single value, no edge points are
        generated and an array of shape (0, 2) is returned.

    """
    edge_points = (
        [(x, yvals[0]) for x in xvals[:-1]] +           # lower edge
        [(xvals[-1], y) for y in yvals[:-1]] +          # right edge
        [(x, yvals[-1]) for x in xvals[:0:-1]] +        # upper edge
        [(xvals[0], y) for y in yvals[:0:-1]]           # left edge
    )

    # The reshape is a no-op for non-empty results; for a degenerate box it
    # turns the shape (0,) array into the documented (0, 2) shape.
    return np.array(edge_points).reshape(-1, 2)
706

707

708
# Utility function to generate meshgrid (in the form needed here)
709
# TODO: add examples of using grid functions directly
710
def meshgrid(xvals, yvals):
    """Generate list of points forming a mesh.

    points = meshgrid(xvals, yvals) generates a list of points that
    corresponds to a grid given by the cross product of the x and y values.

    Parameters
    ----------
    xvals, yvals: 1D array-like
        Coordinates of the mesh lines along the x and y axes.

    Returns
    -------
    grid: 2D array
        Array of points with shape (n * m, 2) defining the mesh, one
        (x, y) pair per row.

    """
    xmesh, ymesh = np.meshgrid(xvals, yvals)

    # Flatten the two coordinate matrices into the columns of a single
    # floating point array of (x, y) pairs.
    grid = np.zeros((xmesh.size, 2))
    grid[:, 0] = xmesh.ravel()
    grid[:, 1] = ymesh.ravel()
    return grid
9✔
734

735

736
# Utility function to generate circular grid
737
def circlegrid(centers, radius, num):
    """Generate list of points around a circle.

    points = circlegrid(centers, radius, num) generates a list of points
    that form a circle around a list of centers.

    Parameters
    ----------
    centers : 2D array-like
        Array of points with shape (p, 2) defining centers of the circles.
        A single point is promoted to a one-row array.
    radius : float
        Radius of the points to be generated around each center.
    num : int
        Number of points to generate around the circle.

    Returns
    -------
    grid: 2D array
        Array of points with shape (p * num, 2) defining the circles.

    """
    centers = np.atleast_2d(np.array(centers))
    grid = np.zeros((centers.shape[0] * num, 2))

    for idx, center in enumerate(centers):
        # Equally spaced angles; the endpoint is excluded so that the
        # point at angle 0 is not repeated at 2*pi.
        angles = np.linspace(0, 2 * math.pi, num, endpoint=False)
        ring = np.array([
            [radius * math.cos(angle), radius * math.sin(angle)]
            for angle in angles])

        # Each center owns a contiguous slice of `num` rows
        start = idx * num
        grid[start:start + num, :] = center + ring

    return grid
9✔
765

766
#
767
# Internal utility functions
768
#
769

770
# Create a system from a callable
771
def _create_system(sys, params):
    # Existing I/O systems are used as-is, as long as they are planar
    if isinstance(sys, NonlinearIOSystem):
        if sys.nstates == 2:
            return sys
        raise ValueError("system must be planar")

    # Make sure that if params is present, it has (non-empty) 'args' key
    if params and not params.get('args', None):
        raise ValueError("params must be dict with key 'args'")

    # Wrap the callable: extra arguments come from params['args'] and the
    # input `u` is ignored (the wrapped system has no inputs).
    def _update(t, x, u, params):
        return sys(t, x, *params.get('args', ()))

    # The wrapped system has no outputs
    def _output(t, x, u, params):
        return np.array([])

    return NonlinearIOSystem(
        _update, _output, states=2, inputs=0, outputs=0, name="_callable")
785

786
# Set axis limits for the plot
787
def _set_axis_limits(ax, pointdata):
    """Grow the axis limits of `ax` so that they include `pointdata`.

    Parameters
    ----------
    ax : axes-like
        Object providing `lines`, `get_xlim`, `get_ylim`, `autoscale`,
        `set_xlim` and `set_ylim`.
    pointdata : list or tuple of 4 numbers, or ndarray
        Either a box given as (xmin, xmax, ymin, ymax) or an array of
        points with x in column 0 and y in column 1.  Any other value
        leaves the current limits unchanged.

    Returns
    -------
    xlim, ylim : list or tuple of 2 numbers
        The limits that were applied to the axes.
    maxlim : number
        The larger of the x and y extents.

    """
    # Get the current axis limits
    if ax.lines:
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
    else:
        # Nothing on the plot => always use new limits
        xlim, ylim = [np.inf, -np.inf], [np.inf, -np.inf]

    # Short utility function for updating axis limits
    def _update_limits(cur, new):
        return [min(cur[0], np.min(new)), max(cur[1], np.max(new))]

    # If we were passed a box, use that to update the limits.  Tuples are
    # accepted as well as lists; otherwise a tuple box on an empty plot
    # would leave the limits at [inf, -inf].
    if isinstance(pointdata, (list, tuple)) and len(pointdata) == 4:
        xlim = _update_limits(xlim, [pointdata[0], pointdata[1]])
        ylim = _update_limits(ylim, [pointdata[2], pointdata[3]])

    elif isinstance(pointdata, np.ndarray):
        pointdata = np.atleast_2d(pointdata)
        xlim = _update_limits(
            xlim, [np.min(pointdata[:, 0]), np.max(pointdata[:, 0])])
        ylim = _update_limits(
            ylim, [np.min(pointdata[:, 1]), np.max(pointdata[:, 1])])

    # Keep track of the largest dimension on the plot
    maxlim = max(xlim[1] - xlim[0], ylim[1] - ylim[0])

    # Set the new limits
    ax.autoscale(enable=True, axis='x', tight=True)
    ax.autoscale(enable=True, axis='y', tight=True)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return xlim, ylim, maxlim
9✔
821

822

823
# Find equilibrium points
824
def _find_equilpts(sys, points, params=None):
    # Collect the distinct equilibrium points found by starting a search
    # from each entry of `points`; results are returned as plain lists.
    found = []
    for x0 in points:
        # Look for an equilibrium point near this point
        xeq, _ueq = find_eqpt(sys, x0, 0, params=params)

        # Skip if the search failed
        if xeq is None:
            continue

        # Skip if this point (to within tolerance) was already recorded
        if any(np.allclose(np.array(prev), xeq) for prev in found):
            continue

        # Save a new point
        found.append(xeq.tolist())

    return found
9✔
845

846

847
def _make_points(pointdata, gridspec, gridtype):
9✔
848
    # Check to see what type of data we got
849
    if isinstance(pointdata, np.ndarray) and gridtype is None:
9✔
850
        pointdata = np.atleast_2d(pointdata)
9✔
851
        if pointdata.shape[1] == 2:
9✔
852
            # Given a list of points => no action required
853
            return pointdata, None
9✔
854

855
    # Utility function to parse (and check) input arguments
856
    def _parse_args(defsize):
9✔
857
        if gridspec is None:
9✔
858
            return defsize
9✔
859

860
        elif not isinstance(gridspec, (list, tuple)) or \
9✔
861
             len(gridspec) != len(defsize):
862
            raise ValueError("invalid grid specification")
9✔
863

864
        return gridspec
9✔
865

866
    # Generate points based on grid type
867
    match gridtype:
9✔
868
        case 'boxgrid' | None:
9✔
869
            gridspec = _parse_args([6, 4])
9✔
870
            points = boxgrid(
9✔
871
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
872
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
873

874
        case 'meshgrid':
9✔
875
            gridspec = _parse_args([9, 6])
9✔
876
            points = meshgrid(
9✔
877
                np.linspace(pointdata[0], pointdata[1], gridspec[0]),
878
                np.linspace(pointdata[2], pointdata[3], gridspec[1]))
879

880
        case 'circlegrid':
9✔
881
            gridspec = _parse_args((0.5, 10))
9✔
882
            if isinstance(pointdata, np.ndarray):
9✔
883
                # Create circles around each point
884
                points = circlegrid(pointdata, gridspec[0], gridspec[1])
9✔
885
            else:
886
                # Create circle around center of the plot
887
                points = circlegrid(
9✔
888
                    np.array(
889
                        [(pointdata[0] + pointdata[1]) / 2,
890
                         (pointdata[0] + pointdata[1]) / 2]),
891
                    gridspec[0], gridspec[1])
892

893
        case _:
9✔
894
            raise ValueError(f"unknown grid type '{gridtype}'")
9✔
895

896
    return points, gridspec
9✔
897

898

899
def _parse_arrow_keywords(kwargs):
    """Extract arrow positions and style from plotting keywords.

    The 'arrows', 'arrow_size' and 'arrow_style' entries are removed from
    `kwargs` (falling back to the 'phaseplot' configuration values).

    Parameters
    ----------
    kwargs : dict
        Keyword dictionary; modified in place.

    Returns
    -------
    arrow_pos : list or 1D array
        Sorted positions at which to place arrows; empty if no arrows
        were requested.
    arrow_style : matplotlib.patches.ArrowStyle or user-supplied value
        Arrow style; a 'simple' style scaled by `arrow_size` when none
        was given.

    Raises
    ------
    ValueError
        If 'arrows' is not an integer, list or ndarray.

    """
    # Get values for params (and pop from list to allow keyword use in plot)
    # TODO: turn this into a utility function (shared with nyquist_plot?)
    arrows = config._get_param(
        'phaseplot', 'arrows', kwargs, None, pop=True)
    arrow_size = config._get_param(
        'phaseplot', 'arrow_size', kwargs, None, pop=True)
    # Popped like the other arrow keywords so that it is not left behind
    # in kwargs for the caller to trip over.
    arrow_style = config._get_param(
        'phaseplot', 'arrow_style', kwargs, None, pop=True)

    # Parse the arrows keyword.  Sequences are handled first because the
    # truth value of a multi-element ndarray is ambiguous and would raise.
    if isinstance(arrows, (list, np.ndarray)):
        if np.size(arrows) == 0:
            arrow_pos = []
        else:
            arrow_pos = np.sort(np.atleast_1d(arrows))
    elif not arrows:
        arrow_pos = []
    elif isinstance(arrows, int):
        N = arrows
        # Space arrows out, starting midway along each "region"
        arrow_pos = np.linspace(0.5/N, 1 + 0.5/N, N, endpoint=False)
    else:
        raise ValueError("unknown or unsupported arrow location")

    # Set the arrow style
    if arrow_style is None:
        arrow_style = mpl.patches.ArrowStyle(
            'simple', head_width=int(2 * arrow_size / 3),
            head_length=arrow_size)

    return arrow_pos, arrow_style
9✔
927

928

929
# TODO: move to ctrlplot?
930
def _create_trajectory(
        sys, revsys, timepts, X0, params, dir, suppress_warnings=False,
        gridtype=None, gridspec=None, xlim=None, ylim=None):
    """Simulate a trajectory from X0 and clip it to the plot window.

    Parameters
    ----------
    sys, revsys : I/O systems
        Systems simulated for the 'forward' and 'reverse' directions.
    timepts : array-like
        Time points passed to `input_output_response`.
    X0 : array-like
        Initial condition.
    params : dict or None
        Parameters passed to `input_output_response`.
    dir : {'forward', 'reverse', 'both'}
        Direction(s) in which to simulate.
    suppress_warnings : bool, optional
        If True, do not warn when a simulation reports failure.
    gridtype, gridspec : optional
        Accepted for signature compatibility; not used here.
    xlim, ylim : 2-element sequences, optional
        Window used to clip the trajectory.  If None, that axis is not
        clipped.

    Returns
    -------
    traj : 2D array
        State trajectory (row 0 = x, row 1 = y) restricted to the window,
        keeping the points adjacent to an in-range point.

    Raises
    ------
    ValueError
        If `dir` is not one of the recognized directions.

    """
    # Reject unknown directions up front; otherwise no trajectory would be
    # built and the code below would fail with an unbound local.
    if dir not in ('forward', 'reverse', 'both'):
        raise ValueError(f"unknown direction '{dir}'")

    # Compute the forward trajectory
    if dir == 'forward' or dir == 'both':
        fwdresp = input_output_response(
            sys, timepts, X0=X0, params=params, ignore_errors=True)
        if not fwdresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {fwdresp.message}")

    # Compute the reverse trajectory
    if dir == 'reverse' or dir == 'both':
        revresp = input_output_response(
            revsys, timepts, X0=X0, params=params, ignore_errors=True)
        if not revresp.success and not suppress_warnings:
            warnings.warn(f"{X0=}, {revresp.message}")

    # Create the trace to plot
    if dir == 'forward':
        traj = fwdresp.states
    elif dir == 'reverse':
        traj = revresp.states[:, ::-1]
    else:
        # Reverse response is flipped in time and placed ahead of the
        # forward response.
        traj = np.hstack([revresp.states[:, :1:-1], fwdresp.states])

    # A missing limit means "no clipping" along that axis
    xmin, xmax = (-np.inf, np.inf) if xlim is None else (xlim[0], xlim[1])
    ymin, ymax = (-np.inf, np.inf) if ylim is None else (ylim[0], ylim[1])

    # Remove points outside the window (keep first point beyond boundary)
    inrange = np.asarray(
        (traj[0] >= xmin) & (traj[0] <= xmax) &
        (traj[1] >= ymin) & (traj[1] <= ymax))
    inrange[:-1] = inrange[:-1] | inrange[1:]   # keep if next point in range
    inrange[1:] = inrange[1:] | inrange[:-1]    # keep if prev point in range

    return traj[:, inrange]
9✔
963

964

965
def _make_timepts(timepts, i):
    """Return the vector of time points to use for trajectory `i`.

    Parameters
    ----------
    timepts : None, scalar, or array-like
        None gives `np.linspace(0, 1)`; a scalar T gives
        `np.linspace(0, T)`; a 2D array-like gives row `i`; anything else
        is returned as an array.
    i : int
        Index of the trajectory (only used for 2D `timepts`).

    Returns
    -------
    1D array of time points.

    """
    if timepts is None:
        return np.linspace(0, 1)

    # Any scalar (including NumPy integer scalars, which are not
    # instances of int) is interpreted as the final time.
    if np.ndim(timepts) == 0:
        return np.linspace(0, timepts)

    # Accept lists and tuples as well as arrays; an ndarray passes through
    # np.asarray unchanged.
    timepts = np.asarray(timepts)
    if timepts.ndim == 2:
        return timepts[i]
    return timepts
×
973

974

975
#
976
# Legacy phase plot function
977
#
978
# Author: Richard Murray
979
# Date: 24 July 2011, converted from MATLAB version (2002); based on
980
# a version by Kristi Morgansen
981
#
982
def phase_plot(odefun, X=None, Y=None, scale=1, X0=None, T=None,
               lingrid=None, lintime=None, logtime=None, timepts=None,
               parms=None, params=(), tfirst=False, verbose=True):

    """(legacy) Phase plot for 2D dynamical systems.

    Produces a vector field or stream line plot for a planar system.  This
    function has been replaced by :func:`~control.phase_plane_plot` and
    the utility functions in :mod:`control.phaseplot`.

    Call signatures:
      phase_plot(func, X, Y, ...) - display vector field on meshgrid
      phase_plot(func, X, Y, scale, ...) - scale arrows
      phase_plot(func, X0=(...), T=Tmax, ...) - display stream lines
      phase_plot(func, X, Y, X0=[...], T=Tmax, ...) - plot both
      phase_plot(func, X0=[...], T=Tmax, lingrid=N, ...) - plot both
      phase_plot(func, X0=[...], logtime=(N, lambda), ...) - stream lines
          with arrows

    Parameters
    ----------
    odefun : callable(x, t, ...)
        Computes the time derivative of the state (compatible with
        odeint).  It accepts a state x of dimension 2 and returns a
        derivative dx/dt of dimension 2.  It is called as
        `odefun(x, t, *params)`, or as `odefun(t, x, *params)` if `tfirst`
        is True.
    X, Y: 3-element sequences, optional, as [start, stop, npts]
        Two 3-element sequences specifying x and y coordinates of a
        grid.  These arguments are passed to linspace and meshgrid to
        generate the points at which the vector field is plotted.  If
        absent (or None), the vector field is not plotted.
    scale: float, optional
        Scale size of arrows; default = 1.  If None, unscaled arrows are
        drawn; if 0, no arrows are drawn; if negative, dots are added at
        the base of the stream line arrows.
    X0: ndarray of initial conditions, optional
        List of initial conditions from which streamlines are plotted.
        Each initial condition should be a pair of numbers.
    T: array-like or number, optional
        Time used to generate streamlines.  If a single number, the
        simulation runs from 0 to T using 100 points.  Otherwise the value
        is used as the vector of simulation times for all initial
        conditions.  Default value = 50.
    lingrid : integer, optional
        If given, draw this many arrows along each stream line, spaced
        linearly in the simulation time index.
    lintime : optional
        Accepted for backward compatibility; not used by this
        implementation.
    logtime : tuple (integer, float), optional
        If a tuple (N, lambda) is given, draw N arrows along each stream
        line using exponential time constant lambda.
    timepts : array-like, optional
        Draw arrows at the given list times [t1, t2, ...]
    parms : tuple, optional
        Deprecated alias for `params`.
    params: tuple, optional
        List of parameters to pass to vector field: `func(x, t, *params)`
    tfirst : bool, optional
        If True, call `func` with signature `func(t, x, ...)`.
    verbose : bool, optional
        If True (default), print a message when `lingrid` or `logtime`
        arrows are used.

    See also
    --------
    box_grid : construct box-shaped grid of initial conditions

    """
    # Generate a deprecation warning
    warnings.warn(
        "phase_plot is deprecated; use phase_plane_plot instead",
        FutureWarning)

    #
    # Figure out ranges for phase plot (argument processing)
    #
    #! TODO: need to add error checking to arguments
    #! TODO: think through proper action if multiple options are given
    #
    autoFlag = False
    logtimeFlag = False
    timeptsFlag = False
    Narrows = 0

    # Get parameters to pass to function
    if parms:
        warnings.warn(
            "keyword 'parms' is deprecated; use 'params'", FutureWarning)
        if params:
            # NOTE(review): ControlArgument must be provided by the
            # module's imports -- confirm it is in scope.
            raise ControlArgument("duplicate keywords 'parms' and 'params'")
        else:
            params = parms

    if lingrid is not None:
        autoFlag = True
        Narrows = lingrid
        if verbose:
            print('Using auto arrows\n')

    elif logtime is not None:
        logtimeFlag = True
        Narrows = logtime[0]
        timefactor = logtime[1]
        if verbose:
            print('Using logtime arrows\n')

    elif timepts is not None:
        timeptsFlag = True
        Narrows = len(timepts)

    # Figure out the set of points for the quiver plot
    #! TODO: Add sanity checks
    elif X is not None and Y is not None:
        x1, x2 = np.meshgrid(
            np.linspace(X[0], X[1], X[2]),
            np.linspace(Y[0], Y[1], Y[2]))
        Narrows = len(x1)

    else:
        # If we weren't given any grid points, don't plot arrows
        Narrows = 0

    # True when arrows are to be drawn along the stream lines (as opposed
    # to a vector field on the X, Y grid)
    stream_arrows = autoFlag or logtimeFlag or timeptsFlag

    if not stream_arrows and Narrows > 0:
        # Now calculate the vector field at those points
        (nr, nc) = x1.shape
        dx = np.empty((nr, nc, 2))
        for i in range(nr):
            for j in range(nc):
                if tfirst:
                    dx[i, j, :] = np.squeeze(
                        odefun(0, [x1[i, j], x2[i, j]], *params))
                else:
                    dx[i, j, :] = np.squeeze(
                        odefun([x1[i, j], x2[i, j]], 0, *params))

        # Plot the quiver plot
        #! TODO: figure out arguments to make arrows show up correctly
        if scale is None:
            # The two vector components live at indices 0 and 1
            plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
        elif scale != 0:
            #! TODO: optimize parameters for arrows
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:, :, 0] * np.abs(scale),
                       dx[:, :, 1] * np.abs(scale), angles='xy')

        #! TODO: Tweak the shape of the plot
        plt.xlabel('x1')
        plt.ylabel('x2')

    # See if we should also generate the streamlines
    if X0 is None or len(X0) == 0:
        return

    # Convert initial conditions to a numpy array
    X0 = np.array(X0)
    (nr, nc) = np.shape(X0)

    # Generate matrices to keep arrow information.  They start out as NaN
    # so that any entry that is skipped below is never uninitialized data.
    x1 = np.full((nr, Narrows), np.nan)
    x2 = np.full((nr, Narrows), np.nan)
    dx = np.full((nr, Narrows, 2), np.nan)

    # See if we were passed a simulation time
    if T is None:
        T = 50

    # Parse the time we were passed
    TSPAN = T
    if isinstance(T, (int, float)):
        TSPAN = np.linspace(0, T, 100)

    # Figure out the limits for the plot
    if scale is None:
        # Assume that the current axis are set as we want them
        alim = plt.axis()
        xmin, xmax = alim[0], alim[1]
        ymin, ymax = alim[2], alim[3]
    else:
        # Use the extent of the initial conditions
        xmin, xmax = np.min(X0[:, 0]), np.max(X0[:, 0])
        ymin, ymax = np.min(X0[:, 1]), np.max(X0[:, 1])

    # Generate the streamlines for each initial condition
    for i in range(nr):
        state = odeint(odefun, X0[i], TSPAN, args=params, tfirst=tfirst)
        time = TSPAN

        plt.plot(state[:, 0], state[:, 1])
        #! TODO: add back in colors for stream lines

        # Plot arrows if quiver parameters were 'auto'
        if stream_arrows:
            # Compute the locations of the arrows
            #! TODO: check this logic to make sure it works in python
            for j in range(Narrows):

                # Figure out starting index; headless arrows start at 0
                k = -1 if scale is None else 0

                # Figure out what time index to use for the next point
                if autoFlag:
                    # Use a linear scaling based on ODE time vector; the
                    # result is converted to int so it can be used as an
                    # array index.
                    tind = int(np.floor((len(time) / Narrows) * (j - k))) + k
                elif logtimeFlag:
                    # Use an exponential time vector
                    # MATLAB: tind = find(time < (j-k) / lambda, 1, 'last')
                    tarr = _find(time < (j - k) / timefactor)
                    tind = tarr[-1] if len(tarr) else 0
                elif timeptsFlag:
                    # Use specified time points
                    # MATLAB: tind = find(time < Y[j], 1, 'last')
                    tarr = _find(time < timepts[j])
                    tind = tarr[-1] if len(tarr) else 0

                # For tailless arrows, skip the first point
                if tind == 0 and scale is None:
                    continue

                # Figure out the arrow at this point on the curve
                x1[i, j] = state[tind, 0]
                x2[i, j] = state[tind, 1]

                # Skip arrows outside of initial condition box
                if (scale is not None or
                        (xmin <= x1[i, j] <= xmax and
                         ymin <= x2[i, j] <= ymax)):
                    if tfirst:
                        v = odefun(0, [x1[i, j], x2[i, j]], *params)
                    else:
                        v = odefun([x1[i, j], x2[i, j]], 0, *params)
                    dx[i, j, 0] = v[0]
                    dx[i, j, 1] = v[1]
                else:
                    dx[i, j, 0] = 0
                    dx[i, j, 1] = 0

    # Plot arrows on the streamlines.  This is only done when an arrow
    # mode was selected: if only X, Y were given, the arrays above were
    # never filled in and there is nothing meaningful to draw.
    if stream_arrows and Narrows > 0:
        if scale is None:
            # Use a tailless arrow
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:, :, 0], dx[:, :, 1], angles='xy')
        elif scale != 0:
            #! TODO: figure out arguments to make arrows show up correctly
            plt.quiver(x1, x2, dx[:, :, 0] * abs(scale),
                       dx[:, :, 1] * abs(scale), angles='xy')

    # scale may be None (tailless arrows), which cannot be compared to 0
    if scale is not None and scale < 0:
        plt.plot(x1, x2, 'b.')          # add dots at base
×
1239
        # set(bp, 'MarkerSize', PP_arrow_markersize)
1240

1241

1242
# Utility function for generating initial conditions around a box
1243
def box_grid(xlimp, ylimp):
    """(legacy) Generate list of points on edge of box.

    list = box_grid([xmin xmax xnum], [ymin ymax ynum]) generates a
    list of points that correspond to a uniform grid along the edge of
    the box defined by the corners [xmin ymin] and [xmax ymax].

    This function is deprecated; it expands each [min, max, num] triple
    with linspace and passes the result to `boxgrid`.
    """

    # Generate a deprecation warning
    warnings.warn(
        "box_grid is deprecated; use phaseplot.boxgrid instead",
        FutureWarning)

    # Expand each (min, max, num) triple into its list of coordinates
    xvals = np.linspace(xlimp[0], xlimp[1], xlimp[2])
    yvals = np.linspace(ylimp[0], ylimp[1], ylimp[2])

    return boxgrid(xvals, yvals)
1259

1260

1261
# TODO: rename to something more useful (or remove??)
1262
def _find(condition):
    """Return the indices where the flattened `condition` is true.

    Private implementation of deprecated matplotlib.mlab.find
    """
    return np.flatnonzero(condition)
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc