• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpenMDAO / dymos / 13169126399

06 Feb 2025 12:21AM UTC coverage: 92.63% (-1.2%) from 93.815%
13169126399

Pull #1142

github

web-flow
Merge 820bb8cfd into cd45db95f
Pull Request #1142: Some cleanup of the Birkhoff transcription and tests.

212 of 241 new or added lines in 15 files covered. (87.97%)

412 existing lines in 16 files now uncovered.

32852 of 35466 relevant lines covered (92.63%)

5.64 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.19
/dymos/utils/testing_utils.py
1
import io
7✔
2

3
from packaging.version import Version
7✔
4

5
import numpy as np
7✔
6

7
from scipy.interpolate import Akima1DInterpolator
7✔
8

9
import openmdao.api as om
7✔
10
import openmdao.utils.assert_utils as _om_assert_utils
7✔
11
from openmdao import __version__ as openmdao_version
7✔
12

13

14
def assert_check_partials(data, atol=1.0E-6, rtol=1.0E-6):
    """
    Wrapper around OpenMDAO's assert_check_partials with a dymos-specific message.

    Calls OpenMDAO's assert_check_partials, but first verifies that the dictionary of assertion
    data is not empty, which happens when dymos.options['include_check_partials'] is False.

    Parameters
    ----------
    data : dict of dicts of dicts
            First key:
                is the component name;
            Second key:
                is the (output, input) tuple of strings;
            Third key:
                is one of ['rel error', 'abs error', 'magnitude', 'J_fd', 'J_fwd', 'J_rev'];

            For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
                forward - fd, adjoint - fd, forward - adjoint.
            For 'J_fd', 'J_fwd', 'J_rev' the value is: A numpy array representing the computed
                Jacobian for the three different methods of computation.
    atol : float
        Absolute error. Default is 1e-6.
    rtol : float
        Relative error. Default is 1e-6.

    Raises
    ------
    AssertionError
        If `data` is empty. Any error raised by OpenMDAO's assert_check_partials for partials
        that are out of tolerance propagates unchanged.
    """
    # Raise explicitly instead of using an ``assert`` statement so that the check is not
    # stripped when Python runs with optimizations enabled (python -O).
    if not data:
        raise AssertionError("No check partials data found.  Is "
                             "dymos.options['include_check_partials'] set to True?")
    _om_assert_utils.assert_check_partials(data, atol, rtol)
43

44

45
def assert_cases_equal(case1, case2, tol=1.0E-12, require_same_vars=True):
    """
    Raise AssertionError if the data in two OpenMDAO Cases is different.

    Variables are matched by promoted name. Entries that are equal in both cases (including
    equal infinities), or NaN in both cases, are treated as identical. An entry that is NaN in
    only one of the cases counts as an infinite error, so it can never hide other differences
    in the same variable.

    Parameters
    ----------
    case1 : om.Case or om.Problem
        The first OpenMDAO Case for comparison. If a Problem is given, its model is used.
    case2 : om.Case or om.Problem
        The second OpenMDAO Case for comparison. If a Problem is given, its model is used.
    tol : float
        The absolute value of the allowable difference in values between two variables.
    require_same_vars : bool
        If True, require that the two cases contain the same set of variables.

    Raises
    ------
    AssertionError
        Raised if require_same_vars is True and the two cases contain different variables,
        if a variable present in both cases has a different size/shape in each, or if such a
        variable has values differing by more than tol. All detected problems are reported
        together in a single message.
    """
    _case1 = case1.model if isinstance(case1, om.Problem) else case1
    _case2 = case2.model if isinstance(case2, om.Problem) else case2

    def _vars_by_prom_name(case):
        # Map promoted name -> metadata for every input and output of the case. Outputs are
        # added second, so they take precedence if an input and output share a promoted name.
        list_kwargs = {'val': True, 'units': True, 'prom_name': True, 'out_stream': None}
        var_meta = {t[1]['prom_name']: t[1] for t in case.list_inputs(**list_kwargs)}
        var_meta.update({t[1]['prom_name']: t[1] for t in case.list_outputs(**list_kwargs)})
        return var_meta

    case1_vars = _vars_by_prom_name(_case1)
    case2_vars = _vars_by_prom_name(_case2)

    # Record an error if the two cases don't contain the same sets of variables.
    diff_err_msg = ''
    if require_same_vars:
        case1_minus_case2 = case1_vars.keys() - case2_vars.keys()
        case2_minus_case1 = case2_vars.keys() - case1_vars.keys()
        if case1_minus_case2 or case2_minus_case1:
            diff_err_msg = '\nrequire_same_vars=True but cases contain different variables.'
        if case1_minus_case2:
            diff_err_msg += f'\nVariables in case1 but not in case2: {sorted(case1_minus_case2)}'
        if case2_minus_case1:
            diff_err_msg += f'\nVariables in case2 but not in case1: {sorted(case2_minus_case1)}'

    shape_errors = []
    val_errors = {}

    for var in sorted(case1_vars.keys() & case2_vars.keys()):
        a = np.asarray(case1_vars[var]['val'])
        b = np.asarray(case2_vars[var]['val'])
        if a.shape != b.shape:
            shape_errors.append(f'\n{var} has shape {a.shape} in case1 '
                                f'but shape {b.shape} in case2')
            continue
        if a.size == 0:
            # Nothing to compare, and np.max would raise on an empty array.
            continue

        # inf - inf yields NaN here; that case is resolved below, so silence the warning.
        with np.errstate(invalid='ignore'):
            err = np.abs(a - b)
        # Without this, a single NaN would make np.max(err) NaN, and since NaN > tol is False,
        # every other difference in the variable would be hidden. Equal entries (including
        # equal infinities) and entries NaN in both cases are identical. A NaN in only one case
        # is reported as an infinite error.
        identical = (a == b) | (np.isnan(a) & np.isnan(b))
        err = np.where(identical, 0.0, np.where(np.isnan(err), np.inf, err))

        max_err = np.max(err)
        mean_err = np.mean(err)
        if max_err > tol:
            val_errors[var] = (max_err, mean_err)

    err_msg = diff_err_msg
    if shape_errors:
        err_msg += '\nThe following variables have different shapes/sizes:' + ''.join(shape_errors)
    if val_errors:
        val_err_msg = io.StringIO()
        val_err_msg.write('\nThe following variables contain different values:\n')
        max_var_len = max(3, max(len(s) for s in val_errors))
        val_err_msg.write(
            f"{'var'.rjust(max_var_len)} {'max error'.rjust(16)} {'mean error'.rjust(16)}\n")
        val_err_msg.write(max_var_len * '-' + ' ' + 16 * '-' + ' ' + 16 * '-' + '\n')
        for varname, (max_err, mean_err) in val_errors.items():
            val_err_msg.write(f"{varname.rjust(max_var_len)} {max_err:16.9e} {mean_err:16.9e}\n")
        err_msg += val_err_msg.getvalue()

    if err_msg:
        raise AssertionError(err_msg)
128

129

130
def _write_out_timeseries_values_out_of_tolerance(isclose, rel_tolerance, abs_tolerance,
7✔
131
                                                  t_check, x_check, x_ref):
132
    """
133
    Helper function used to write out a table of values indicating which timeseries values
134
    were out of tolerance.
135

136
    Parameters
137
    ----------
138
    isclose : array of bool
139
        Boolean array indicating where data value is in tolerance. Has same shape as the
140
        time series array
141
    rel_tolerance : float
142
        Allowed relative tolerance error
143
    abs_tolerance : float
144
        Allowed absolute tolerance error
145
    t_check : np.array
146
        Array of time values for the timeseries
147
    x_check : np.array
148
        Array of data values for the timeseries to be check/compared to the reference value, x_ref
149
    x_ref : np.array
150
        Array of data values for the timeseries to be used as the reference
151
    """
152
    err_msg = f"The following timeseries data are out of tolerance due to absolute (" \
7✔
153
              f"{abs_tolerance}) or relative ({rel_tolerance}) tolerance violations\n"
154
    header = f"{'time_index':10s} | " + \
7✔
155
             f"{'data_indices':12s} | " + \
156
             f"{'time':13s} | " + \
157
             f"{'ref_data':13s} | " + \
158
             f"{'checked_data':13s} | " + \
159
             f"{'abs_error':13s} | " + \
160
             f"{'rel_error':13} | " + \
161
             " ABS or REL error "
162
    err_msg += f"{header}\n"
7✔
163
    err_msg += len(header) * '-' + '\n'
7✔
164

165
    rel_error_max = 0.0
7✔
166
    err_line_max = 0
7✔
167
    for idx, item_close in np.ndenumerate(isclose):
7✔
168
        if not item_close:
7✔
169
            error_string = ''
7✔
170
            abs_error = abs(x_check[idx] - x_ref[idx])
7✔
171
            if x_ref[idx] != 0.0:
7✔
172
                rel_error = abs(x_check[idx] - x_ref[idx]) / abs(x_ref[idx])
7✔
173
            else:
174
                rel_error = float('nan')
×
175
            if abs_tolerance is not None:
7✔
176
                if abs_error > abs_tolerance:
7✔
177
                    error_string += ' >ABS_TOL'
7✔
178

179
            if rel_tolerance is not None:
7✔
180
                if rel_error > rel_tolerance:
7✔
181
                    error_string += ' >REL_TOL'
7✔
182

183
            err_line = f"{idx[0]:10,d} | {str(idx[1:]):>12s} | {t_check[idx[0]]:13.6e} |" \
7✔
184
                       f"{x_ref[idx]:13.6e} | {x_check[idx]:13.6e} | {abs_error:13.6e} | " \
185
                       f"{rel_error:13.6e} | {error_string}\n"
186
            err_msg += err_line
7✔
187

188
            if rel_error > rel_error_max:
7✔
189
                rel_error_max = rel_error
7✔
190
                err_line_max = err_line
7✔
191

192
    # show the item with the max rel error
193
    max_rel_error_header_txt = 'Time series data value with the largest relative error'
7✔
194
    max_rel_error_msg = f"\n{len(max_rel_error_header_txt) * '#'}\n{max_rel_error_header_txt}\n" \
7✔
195
                        f"{len(max_rel_error_header_txt) * '#'}\n"
196

197
    max_rel_error_msg += f"{header}\n"
7✔
198
    max_rel_error_msg += len(header) * '-' + '\n'
7✔
199
    max_rel_error_msg += err_line_max
7✔
200

201
    err_msg += max_rel_error_msg
7✔
202

203
    return err_msg
7✔
204

205

206
def assert_timeseries_near_equal(t_ref, x_ref, t_check, x_check, abs_tolerance=None,
                                 rel_tolerance=None):
    """
    Assert that two timeseries of data are approximately equal.

    The first timeseries, defined by t_ref, x_ref, serves as the reference.

    The second timeseries, defined by t_check, x_check is what is checked for near equality.

    The check is done by fitting a 1D interpolant to the reference, and then comparing
    the values of the interpolant at the times in t_check. The check for errors within
    tolerance are done on a point-by-point basis. If any point is out of tolerance, throw
    an AssertionError.

    Only the times where the two timeseries overlap are used for the check.

    When both abs_tolerance and rel_tolerance are given, only one is actually used for any given
    data point. When the absolute value of the interpolated reference value is small, the
    abs_tolerance is used, otherwise the rel_tolerance is used. The transition point is given by

        abs_tolerance / rel_tolerance

    If rel_tolerance is zero in that case, abs_tolerance is used at every point. NaN values are
    always treated as out of tolerance.

    Parameters
    ----------
    t_ref : np.array
        Time values for the reference timeseries.
    x_ref : np.array
        Data values for the reference timeseries.
    t_check : np.array
        Time values for the timeseries that is compared to the reference.
    x_check : np.array
        Data values for the timeseries that is compared to the reference.
    abs_tolerance : float
        The absolute tolerance for any errors along at each point checked.
    rel_tolerance : float
        The relative tolerance for any errors along at each point checked.

    Raises
    ------
    ValueError
        If both tolerances are None, if the per-time data shapes of the two timeseries differ,
        or if the two timeseries do not overlap in time.
    AssertionError
        When one or more elements of the timeseries to be checked are not with in the desired
        tolerance of the interpolated reference timeseries, an AssertionError is raised.
    """
    if abs_tolerance is None and rel_tolerance is None:
        raise ValueError('abs_tolerance and rel_tolerance cannot be both None')

    # get shapes for the time series values
    shape_ref = x_ref.shape[1:]
    shape_check = x_check.shape[1:]
    if shape_ref != shape_check:
        raise ValueError('The shape of the variable in the two timeseries is not equal '
                         f'x_ref is {shape_ref}  x_check is {shape_check}')

    # Times may be given as (n,) or (n, 1) arrays; work with flat 1D views so that the
    # overlap bounds below are scalars.
    t_ref = np.ravel(t_ref)
    t_check = np.ravel(t_check)

    # get the overlapping time period between t_ref and t_check
    t_begin = max(t_ref[0], t_check[0])
    t_end = min(t_ref[-1], t_check[-1])
    if t_begin > t_end:
        raise ValueError("There is no overlapping time between the two time series")

    # Fit the interpolant to the reference data flattened to (num_times, num_elements).
    # np.unique sorts the times and drops duplicates (e.g. repeated segment-boundary times),
    # keeping the data row of the first occurrence of each time.
    num_elements = np.prod(shape_ref, dtype=int)
    x_ref_flat = np.reshape(x_ref, (x_ref.shape[0], num_elements))
    t_ref_unique, idxs_unique_ref = np.unique(t_ref, return_index=True)
    interp = Akima1DInterpolator(t_ref_unique, x_ref_flat[idxs_unique_ref, ...])

    # only want t_check in the overlapping range of t_begin and t_end
    in_range = np.logical_and(t_check >= t_begin, t_check <= t_end)
    t_check = np.compress(in_range, t_check)
    x_check = np.compress(in_range, x_check, axis=0)

    # get the interpolated values of the reference at the values of t_check
    # Reshape back to unflattened data values
    x_ref_interp = np.reshape(interp(t_check), (t_check.size,) + shape_ref)

    if abs_tolerance is None:  # so only have rel_tolerance
        isclose = np.isclose(x_check, x_ref_interp, rtol=rel_tolerance, atol=0.0)
    elif rel_tolerance is None:  # so only have abs_tolerance
        isclose = np.isclose(x_check, x_ref_interp, rtol=0.0, atol=abs_tolerance)
    else:  # need to use a hybrid of abs and rel tolerances
        # Use rel_tolerance where |x_ref| >= abs_tolerance / rel_tolerance and abs_tolerance
        # elsewhere. Written as a product so a zero rel_tolerance cannot divide by zero.
        # A NaN reference fails this comparison and lands in the abs_tolerance branch, where
        # np.isclose reports it as out of tolerance, as the single-tolerance branches do.
        use_rel = rel_tolerance * np.abs(x_ref_interp) >= abs_tolerance
        isclose = np.where(use_rel,
                           np.isclose(x_check, x_ref_interp, rtol=rel_tolerance, atol=0.0),
                           np.isclose(x_check, x_ref_interp, rtol=0.0, atol=abs_tolerance))

    if not np.all(isclose):
        err_msg = _write_out_timeseries_values_out_of_tolerance(isclose,
                                                                rel_tolerance,
                                                                abs_tolerance,
                                                                t_check,
                                                                x_check,
                                                                x_ref_interp,
                                                                )
        raise AssertionError(err_msg)
350

351

352
def _get_reports_dir(prob):
    """
    Return the reports directory of a Problem under either OpenMDAO reports API.

    The reports API changed between OpenMDAO 3.18 and 3.19. Supporting both forms lets the
    tests run against older versions of OpenMDAO.

    Parameters
    ----------
    prob : om.Problem
        The Problem whose reports directory is requested.

    Returns
    -------
    object
        The reports directory, exactly as returned by the OpenMDAO API in use.
    """
    if not Version(openmdao_version) > Version("3.18"):
        # Older OpenMDAO: the reports system provides a module-level lookup function.
        from openmdao.utils.reports_system import get_reports_dir
        return get_reports_dir(prob)

    # Newer OpenMDAO: the Problem itself knows where its reports are written.
    return prob.get_reports_dir()
361

362

363
class PhaseStub:
    """
    A minimal stand-in for a dymos Phase, used when unit testing config_io.

    It provides only what those tests need: ``nonlinear_solver`` and ``linear_solver``
    attributes (both None) and a ``classify_var`` method that reports every variable as 'ode'.
    """

    def __init__(self):
        self.nonlinear_solver = self.linear_solver = None

    def classify_var(self, name):
        """
        Classify a variable; this stub reports every variable as an 'ode' variable.

        Parameters
        ----------
        name : str
            The name of the variable to classify. It is ignored by this stub.

        Returns
        -------
        str
            The variable classification, which is always 'ode'.
        """
        classification = 'ode'
        return classification
388

389

390
class SimpleODE(om.ExplicitComponent):
    """
    A simple scalar ODE for testing purposes: x_dot = x - t**2 + p.

    Source: https://math.okstate.edu/people/yqwang/teaching/math4513_fall11/Notes/rungekutta.pdf

    Parameters
    ----------
    **kwargs : dict of keyword arguments
        Keyword arguments that will be mapped into the Component options.
    """

    def initialize(self):
        """
        Declare options for SimpleODE.
        """
        self.options.declare('num_nodes', types=(int,))

    def setup(self):
        """
        Add inputs and outputs to SimpleODE and declare its diagonal partials.
        """
        nn = self.options['num_nodes']

        for input_name, input_units in (('x', 's**2'), ('t', 's'), ('p', 's**2')):
            self.add_input(input_name, shape=(nn,), units=input_units)

        self.add_output('x_dot', shape=(nn,), units='s')

        # Each node's output depends only on that node's inputs, so every sub-jacobian is
        # diagonal. The x and p partials are the constant 1; the t partial is filled in by
        # compute_partials.
        diag = np.arange(nn, dtype=int)
        self.declare_partials(of='x_dot', wrt='x', rows=diag, cols=diag, val=1.0)
        self.declare_partials(of='x_dot', wrt='t', rows=diag, cols=diag)
        self.declare_partials(of='x_dot', wrt='p', rows=diag, cols=diag, val=1.0)

    def compute(self, inputs, outputs):
        """
        Compute the outputs of SimpleODE.

        Parameters
        ----------
        inputs : Vector
            Vector of inputs.
        outputs : Vector
            Vector of outputs.
        """
        outputs['x_dot'] = inputs['x'] - inputs['t'] ** 2 + inputs['p']

    def compute_partials(self, inputs, partials):
        """
        Compute the non-constant partials of SimpleODE.

        Parameters
        ----------
        inputs : Vector
            Vector of inputs.
        partials : Jacobian
            Sub-jacobian components, written as partials[output_name, input_name].
        """
        # d(x - t**2 + p)/dt = -2 t at each node.
        partials['x_dot', 't'] = -2 * inputs['t']
452

453

454
class SimpleVectorizedODE(om.ExplicitComponent):
    """
    A simple vector-valued ODE with two state components per node.

        z_dot[:, 0] = z[:, 0] - t**2 + p
        z_dot[:, 1] = 10 * t

    Source: https://math.okstate.edu/people/yqwang/teaching/math4513_fall11/Notes/rungekutta.pdf

    Parameters
    ----------
    **kwargs : dict of keyword arguments
        Keyword arguments that will be mapped into the Component options.
    """

    def initialize(self):
        """
        Declare options for SimpleVectorizedODE.
        """
        self.options.declare('num_nodes', types=(int,))

    def setup(self):
        """
        Add inputs and outputs to SimpleVectorizedODE and declare its sparse partials.
        """
        nn = self.options['num_nodes']

        self.add_input('z', shape=(nn, 2), units='s**2')
        self.add_input('t', shape=(nn,), units='s')
        self.add_input('p', shape=(nn,), units='s**2')

        self.add_output('z_dot', shape=(nn, 2), units='s')

        # In flattened (row-major) order, z_dot[i, j] and z[i, j] sit at index 2*i + j, so the
        # even entries hold the first component of each node and the odd entries the second.
        node_idxs = np.arange(nn, dtype=int)

        # z_dot[i, 0] depends on z[i, 0] with a constant partial of 1.
        self.declare_partials(of='z_dot', wrt='z',
                              rows=2 * node_idxs, cols=2 * node_idxs, val=1.0)

        # Both components of node i depend on t[i]; values are set in compute_partials.
        self.declare_partials(of='z_dot', wrt='t',
                              rows=np.arange(2 * nn, dtype=int),
                              cols=np.repeat(node_idxs, 2))

        # Only z_dot[i, 0] depends on p[i], with a constant partial of 1.
        self.declare_partials(of='z_dot', wrt='p',
                              rows=2 * node_idxs, cols=node_idxs.copy(), val=1.0)

    def compute(self, inputs, outputs):
        """
        Compute the outputs of SimpleVectorizedODE.

        Parameters
        ----------
        inputs : Vector
            Vector of inputs.
        outputs : Vector
            Vector of outputs.
        """
        z = inputs['z']
        t = inputs['t']
        z_dot = outputs['z_dot']
        z_dot[:, 0] = z[:, 0] - t ** 2 + inputs['p']
        z_dot[:, 1] = 10 * t

    def compute_partials(self, inputs, partials):
        """
        Compute the non-constant partials of SimpleVectorizedODE.

        Parameters
        ----------
        inputs : Vector
            Vector of inputs.
        partials : Jacobian
            Sub-jacobian components, written as partials[output_name, input_name].
        """
        # Even entries: d(z_dot[:, 0])/dt = -2 t. Odd entries: d(z_dot[:, 1])/dt = 10.
        dzdot_dt = partials['z_dot', 't']
        dzdot_dt[0::2] = -2 * inputs['t']
        dzdot_dt[1::2] = 10
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc