google / openhtf / 15425700896

03 Jun 2025 07:05PM UTC coverage: 62.271% (-0.006%) from 62.277%

Pull #1231

github

copybara-github
Support specifying a tolerance when using `assertMeasured`.

PiperOrigin-RevId: 766756498
Pull Request #1231: Support specifying a tolerance when using `assertMeasured`.

4 of 6 new or added lines in 1 file covered. (66.67%)

6 existing lines in 1 file now uncovered.

4689 of 7530 relevant lines covered (62.27%)

3.11 hits per line

Source File: /openhtf/util/test.py (coverage: 87.83%)

# Copyright 2016 Google Inc. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit test helpers for OpenHTF tests and phases.

This module provides utilities for unit testing OpenHTF test phases and
whole tests.  Primarily, there are:
1. Mechanisms to aid in running phases and tests.
2. Convenience methods to mock plugs.
3. Assertions to validate phase/test output.

The primary class in this module is the TestCase class, which is a subclass
of unittest.TestCase that provides some extra utility.  Use it the same way
you would use unittest.TestCase.  See below for examples.

Since the test executor manages the plugs, TestCase.plugs and
TestCase.auto_mock_plugs may be used to set or access plug instances.  Also
available is a test method decorator, @patch_plugs, but it is less flexible and
should be avoided in new code. In both cases, limit yourself to one phase/test
execution per test method to avoid surprises with plug lifetimes.

Lastly, while not implemented here, it's common to need to temporarily alter
configuration values for individual tests.  This can be accomplished with the
@CONF.save_and_restore decorator (see docs in configuration.py, examples below).
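
The save-and-restore idea can be sketched with the standard library alone:
unittest.mock.patch.dict temporarily overrides keys of a plain dict and
restores them afterwards. This is an analogy under that assumption, not how
openhtf.util.configuration implements CONF.save_and_restore. SETTINGS and
phase_variance_for_test are hypothetical names.

```python
from unittest import mock

# Hypothetical settings dict standing in for CONF.
SETTINGS = {'phase_variance': 'default_variance', 'station_id': 'bench-1'}


@mock.patch.dict(SETTINGS, {'phase_variance': 'test_phase_variance'})
def phase_variance_for_test():
  # The override is only visible while the decorated function runs.
  return SETTINGS['phase_variance']


overridden = phase_variance_for_test()
# patch.dict restores the original contents once the function returns.
restored = SETTINGS['phase_variance']
```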

A few isolated examples (see also test/util/test_test.py for more usage):

  import openhtf
  from openhtf.util import configuration
  from openhtf.util import test

  CONF = configuration.CONF

  import my_plug  # Contains MyPlug, patched by module path below.
  import mytest  # Contains phases under test.

  class PhasesTest(test.TestCase):

    # Using TestCase.execute_phase_or_test, which allows more flexibility.
    def test_using_execute_phase_or_test(self):
      self.auto_mock_plugs(PlugA)
      # Use the stub object below instead of PlugB.
      self.plugs[PlugB] = PlugBStub()
      self.plugs[PlugA].read_something.return_value = 1234

      # Run your OpenHTF phase/test, returning phase record. Do only one of
      # these per test method to avoid unexpected behavior with plugs.
      phase_record = self.execute_phase_or_test(mytest.first_phase)
      self.plugs[PlugA].read_something.assert_called_once_with()
      # assert* methods for checking phase/test records are defined in TestCase.
      self.assertPhaseContinue(phase_record)

    # Decorate with CONF.save_and_restore to temporarily set CONF values.
    # NOTE: This must come before yields_phases.
    @CONF.save_and_restore(phase_variance='test_phase_variance')
    # Decorate the test* method with this to be able to yield a phase to run it.
    @test.yields_phases
    def test_first_phase(self):
      phase_record = yield mytest.first_phase
      # Check a measurement value.
      self.assertMeasured(phase_record, 'my_measurement', 'value')
      # Check that the measurement outcome was PASS.
      self.assertMeasurementPass(phase_record, 'my_measurement')

    @test.patch_plugs(mock_my_plug='my_plug.MyPlug')
    def test_second_phase(self, mock_my_plug):  # arg must match keyword above.
      # mock_my_plug is a mock autospecced from my_plug.MyPlug, and will be
      # passed to yielded test phases in place of a real instance.  You can
      # access it here to configure return values (and later to assert calls
      # to plug methods).
      mock_my_plug.measure_voltage.return_value = 5.0

      # Trigger a phase (or openhtf.Test instance) to run by yielding it.  The
      # resulting PhaseRecord is yielded back (or TestRecord if you yielded an
      # openhtf.Test instance instead of a phase).
      phase_record = yield mytest.second_phase  # uses my_plug.MyPlug

      # Apply assertions to the output, typically using the utility methods on
      # self; see below for an exhaustive list of such utility assertions.
      self.assertPhaseContinue(phase_record)

      # You can apply any assertions on the mocked plug here.
      mock_my_plug.measure_voltage.assert_called_once_with()

      # If you want to patch the plugs yourself, use mock.patch(.object) on the
      # plug class; plug instances are available in the `plugs` attribute once
      # the phase/test has been run:
      self.plugs[my_plug.MyPlug].measure_voltage.assert_called_once_with()

    @test.patch_plugs(mock_my_plug='my_plug.MyPlug')
    def test_multiple(self, mock_my_plug):
      # You can also yield an entire openhtf.Test() object.  If you do, you get
      # a TestRecord yielded back instead of a PhaseRecord.
      test_rec = yield openhtf.Test(mytest.first_phase, mytest.second_phase)

      # Some utility assertions are provided for operating on test records (see
      # below for a full list).
      self.assertTestPass(test_rec)

List of assertions that can be used with PhaseRecord results:

  assertPhaseContinue(phase_record)
  assertPhaseRepeat(phase_record)
  assertPhaseStop(phase_record)
  assertPhaseError(phase_record, exc_type=None)

List of assertions that can be used with TestRecord results:

  assertTestPass(test_rec)
  assertTestFail(test_rec)
  assertTestError(test_rec, exc_type=None)
  assertTestOutcomeCode(test_rec, code)

List of assertions that can be used with either PhaseRecords or TestRecords:

  assertMeasured(phase_or_test_rec, measurement, value=mock.ANY)
  assertNotMeasured(phase_or_test_rec, measurement)
  assertMeasurementPass(phase_or_test_rec, measurement)
  assertMeasurementFail(phase_or_test_rec, measurement)
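
A note on the value=mock.ANY default of assertMeasured: mock.ANY compares
equal to every object, so omitting the value only checks that the measurement
was taken. The sketch below illustrates this with a hypothetical
check_measured helper operating on a plain dict of measured values. It is not
how TestCase.assertMeasured is implemented.

```python
from unittest import mock


def check_measured(measured_values, name, value=mock.ANY):
  # Hypothetical helper: measured_values maps measurement names to values.
  if name not in measured_values:
    raise AssertionError('%s was not measured' % name)
  # With the mock.ANY default this comparison never fails.
  if measured_values[name] != value:
    raise AssertionError('%s: expected %r, got %r' %
                         (name, value, measured_values[name]))


values = {'my_measurement': 'value', 'voltage': 5.0}
check_measured(values, 'voltage')  # Any value is accepted.
check_measured(values, 'my_measurement', 'value')  # Exact match required.
```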
"""

from collections.abc import Callable as CollectionsCallable, Iterator
import contextlib
import functools
import inspect
import logging
import os
import pathlib
import pstats
import sys
import tempfile
import types
import typing
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Text,
    Tuple,
    Type,
    Union,
)
import unittest
from unittest import mock

import attr
from openhtf import plugs
from openhtf import util
from openhtf.core import base_plugs
from openhtf.core import diagnoses_lib
from openhtf.core import measurements
from openhtf.core import phase_collections
from openhtf.core import phase_descriptor
from openhtf.core import phase_executor
from openhtf.core import phase_nodes
from openhtf.core import test_descriptor
from openhtf.core import test_executor
from openhtf.core import test_record
from openhtf.core import test_state
from openhtf.plugs import device_wrapping
from openhtf.util import logs
from openhtf.util import text

logs.CLI_LOGGING_VERBOSITY = 2

# TestApi.dut_id attribute when running unit tests with the module-supplied
# test start function.
TEST_DUT_ID = 'TestDutId'


# Maximum number of measurements per phase to be printed to the assertion
# error message for test failures.
_MAXIMUM_NUM_MEASUREMENTS_PER_PHASE = 10


class InvalidTestError(Exception):
  """Raised when there's something invalid about a test."""


class _ValidTimestamp(int):

  def __eq__(self, other):
    return other is not None and other > 0


VALID_TIMESTAMP = _ValidTimestamp()


@attr.s(slots=True, frozen=True)
class TestNode(phase_nodes.PhaseNode):
  """General base class for comparison nodes.

  This is used to test functions that create phase nodes; it cannot be run as
  part of an actual test run.
  """

  def copy(self: phase_nodes.WithModifierT) -> phase_nodes.WithModifierT:
    """Create a copy of the PhaseNode."""
    return self

  def with_args(self: phase_nodes.WithModifierT,
                **kwargs: Any) -> phase_nodes.WithModifierT:
    """Send these keyword-arguments when phases are called."""
    del kwargs  # Unused.
    return self

  def with_plugs(
      self: phase_nodes.WithModifierT,
      **subplugs: Type[base_plugs.BasePlug]) -> phase_nodes.WithModifierT:
    """Substitute plugs for placeholders for this phase, error on unknowns."""
    del subplugs  # Unused.
    return self

  def load_code_info(
      self: phase_nodes.WithModifierT) -> phase_nodes.WithModifierT:
    """Load code info for all contained phases."""
    return self

  def apply_to_all_phases(self, func: Any) -> 'TestNode':
    return self


@attr.s(slots=True, frozen=True, eq=False)
class PhaseNodeNameComparable(TestNode):
  """Compares truthfully against any phase node with the same name.

  This is used to test functions that create phase nodes; it cannot be run as
  part of an actual test run.
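
A rough analogue of this comparison, written without attrs or OpenHTF: an
object that is equal to anything with a matching name lets a test assert on
the shape of a generated node list while ignoring per-node details.
NameOnlyComparable and FakeNode are hypothetical names for illustration, and
unlike the real class this sketch uses getattr, so objects without a name
attribute simply compare unequal.

```python
class NameOnlyComparable:
  # Hypothetical analogue: equal to any object whose `name` matches.

  def __init__(self, name):
    self.name = name

  def __eq__(self, other):
    return self.name == getattr(other, 'name', None)


class FakeNode:
  # Hypothetical node carrying extra data that the comparison ignores.

  def __init__(self, name, kwargs):
    self.name = name
    self.kwargs = kwargs


built = [FakeNode('setup', {'retries': 3}), FakeNode('measure', {})]
expected = [NameOnlyComparable('setup'), NameOnlyComparable('measure')]
# List equality compares element by element using NameOnlyComparable.__eq__.
matches = expected == built
```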
  """

  name = attr.ib(type=Text)

  def _asdict(self) -> Dict[Text, Any]:
    """Returns a base type dictionary for serialization."""
    return {'name': self.name}

  def __eq__(self, other: phase_nodes.PhaseNode) -> bool:
    return self.name == other.name


@attr.s(slots=True, frozen=True, eq=False, init=False)
class PhaseNodeComparable(TestNode):
  """Compares truthfully only against an identical PhaseNodeComparable.

  Two instances are equal only when their name, args, and kwargs all match.
  This is used to test functions that create phase nodes; it cannot be run as
  part of an actual test run.
  """

  name = attr.ib(type=Text)
  args = attr.ib(type=Tuple[Any, ...], factory=tuple)
  kwargs = attr.ib(type=Dict[Text, Any], factory=dict)

  def __init__(self, name, *args, **kwargs):
    super(PhaseNodeComparable, self).__init__()
    object.__setattr__(self, 'name', name)
    object.__setattr__(self, 'args', tuple(args))
    object.__setattr__(self, 'kwargs', kwargs)

  @classmethod
  def create_constructor(cls, name) -> Callable[..., 'PhaseNodeComparable']:

    def constructor(*args, **kwargs):
      return cls(name, *args, **kwargs)

    return constructor

  def _asdict(self) -> Dict[Text, Any]:
    return {'name': self.name, 'args': self.args, 'kwargs': self.kwargs}

  def __eq__(self, other: phase_nodes.PhaseNode) -> bool:
    return (isinstance(other, PhaseNodeComparable) and
            self.name == other.name and self.args == other.args and
            self.kwargs == other.kwargs)


class FakeTestApi(test_descriptor.TestApi):
  """A fake TestApi used to test non-phase helper functions."""

  def __init__(self):
    self.mock_logger = mock.create_autospec(logging.Logger)
    self.mock_phase_state = mock.create_autospec(
        test_state.PhaseState, logger=self.mock_logger)
    self.mock_test_state = mock.create_autospec(
        test_state.TestState,
        test_record=test_record.TestRecord('DUT', 'STATION'),
        user_defined_state={})
    super(FakeTestApi, self).__init__(
        measurements={},
        running_phase_state=self.mock_phase_state,
        running_test_state=self.mock_test_state)


def filter_phases_by_names(phase_recs: Iterable[test_record.PhaseRecord],
                           *names: Text) -> Iterable[test_record.PhaseRecord]:
  all_names = set(names)
  for phase_rec in phase_recs:
    if phase_rec.name in all_names:
      yield phase_rec


def filter_phases_by_outcome(
    phase_recs: Iterable[test_record.PhaseRecord],
    outcome: test_record.PhaseOutcome) -> Iterable[test_record.PhaseRecord]:
  for phase_rec in phase_recs:
    if phase_rec.outcome == outcome:
      yield phase_rec


def _merge_stats(stats: pstats.Stats, filepath: pathlib.Path) -> None:
  """Merges the provided Stats into filepath (created if not present)."""
  stats_to_combine = [stats]
  try:
    stats_to_combine.append(pstats.Stats(str(filepath)))
  except FileNotFoundError:
    pass
  test_executor.combine_profile_stats(stats_to_combine, str(filepath))


class PhaseOrTestIterator(Iterator):

  def __init__(self, test_case, iterator, mock_plugs, phase_user_defined_state,
               phase_diagnoses):
    """Create an iterator for iterating over Tests or phases to run.

    This should only be instantiated internally.
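
Iteration is driven through the generator protocol: each step sends the
previous result into the wrapped generator with send() and receives the next
Test or phase to run. A minimal sketch of that loop, assuming plain strings in
place of phases and a hypothetical fake_run in place of _handle_phase and
_handle_test:

```python
results = []


def test_body():
  # Hypothetical test method body: yield items to run, receive results back.
  first = yield 'phase_a'
  second = yield 'phase_b'
  results.append((first, second))


def fake_run(item):
  # Hypothetical runner standing in for _handle_phase/_handle_test.
  return item.upper() + '_RECORD'


gen = test_body()
last_result = None
try:
  while True:
    # The first send(None) starts the generator; later sends deliver the
    # previous result as the value of the pending yield expression.
    item = gen.send(last_result)
    last_result = fake_run(item)
except StopIteration:
  pass  # The generator returned, so there is nothing left to run.
```

In PhaseOrTestIterator the StopIteration raised by send() simply propagates
out of __next__, which ends the caller's for loop.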

    Args:
      test_case: TestCase subclass where the test case function is defined.
      iterator: Child iterator to use for obtaining Tests or test phases, must
        be a generator.
      mock_plugs: Dict mapping plug types to mock objects to use instead of
        actually instantiating that type.
      phase_user_defined_state: If not None, a dictionary that will be added to
        the test_state.user_defined_state when handling phases.
      phase_diagnoses: If not None, must be a list of Diagnosis instances; these
        are added to the DiagnosesManager when handling phases.

    Raises:
      InvalidTestError: when iterator is not a generator.
    """
    if not isinstance(iterator, types.GeneratorType):
      raise InvalidTestError(
          'Methods decorated with patch_plugs or yields_phases must yield '
          'test phases or openhtf.Test objects.', iterator)

    # Since we want to run single phases, we instantiate our own PlugManager.
    # Don't do this sort of thing outside OpenHTF unless you really know what
    # you're doing (http://imgur.com/iwBCmQe).
    self.plug_manager = plugs.PlugManager()
    self.test_case = test_case
    self.iterator = iterator
    self.mock_plugs = mock_plugs
    self.last_result = None
    if not phase_user_defined_state:
      phase_user_defined_state = {}
    self.phase_user_defined_state = phase_user_defined_state
    if not phase_diagnoses:
      phase_diagnoses = []
    self.phase_diagnoses = phase_diagnoses

  def _initialize_plugs(self, plug_types):
    # Make sure we initialize any plugs, this will ignore any that have already
    # been initialized.
    plug_types = list(plug_types)
    self.plug_manager.initialize_plugs(
        plug_cls for plug_cls in plug_types if plug_cls not in self.mock_plugs)
    for plug_type, plug_value in self.mock_plugs.items():
      self.plug_manager.update_plug(plug_type, plug_value)
    for plug_type in plug_types:
      self.test_case.plugs[plug_type] = (
          self.plug_manager.get_plug_by_class_path(
              self.plug_manager.get_plug_name(plug_type)))

  def _handle_phase(self, phase_desc):
    """Handle execution of a single test phase."""
    phase_descriptor.check_for_duplicate_results(iter([phase_desc]), [])
    logs.configure_logging()
    self._initialize_plugs(phase_plug.cls for phase_plug in phase_desc.plugs)

    # Cobble together a fake TestState to pass to the test phase.
    test_options = test_descriptor.TestOptions()
    with mock.patch.object(
        plugs, 'PlugManager', new=lambda _, __: self.plug_manager):
      test_state_ = test_state.TestState(
          test_descriptor.TestDescriptor(
              phase_collections.PhaseSequence((phase_desc,)),
              phase_desc.code_info, {}), 'Unittest:StubTest:UID', test_options)
      test_state_.mark_test_started()

    test_state_.user_defined_state.update(self.phase_user_defined_state)
    for diag in self.phase_diagnoses:
      test_state_.diagnoses_manager._add_diagnosis(diag)  # pylint: disable=protected-access
      test_state_.test_record.add_diagnosis(diag)

    # Save the test_state to the test case's last_test_state attribute to give
    # it access to the underlying state.
    self.test_case.last_test_state = test_state_

    # Actually execute the phase, saving the result in our return value.
    executor = phase_executor.PhaseExecutor(test_state_)
    profile_filepath = self.test_case.get_profile_filepath()
    # Log an exception stack when a Phase errors out.
    with mock.patch.object(
        phase_executor.PhaseExecutorThread,
        '_log_exception',
        side_effect=logging.exception):
      # Use _execute_phase_once because we want to expose all possible outcomes.
      phase_result, profile_stats = executor._execute_phase_once(
          phase_desc,
          is_last_repeat=False,
          run_with_profiling=profile_filepath,
          subtest_rec=None)

    if profile_filepath is not None:
      _merge_stats(profile_stats, profile_filepath)

    if phase_result.raised_exception:
      failure_message = phase_result.phase_result.get_traceback_string()
    else:
      failure_message = None
    return test_state_.test_record.phases[-1], failure_message

  def _handle_test(self, test):
    self._initialize_plugs(test.descriptor.plug_types)

    # We'll need a place to stash the resulting TestRecord.
    record_saver = util.NonLocalResult()
    test.add_output_callbacks(
        lambda record: setattr(record_saver, 'result', record))

    profile_filepath = self.test_case.get_profile_filepath()
    if profile_filepath is None:
      profile_tempfile = None
    else:
      profile_tempfile = tempfile.NamedTemporaryFile(delete=False)
    # Mock the PlugManager to use ours instead, and execute the test.
    with mock.patch.object(
        plugs, 'PlugManager', new=lambda _, __: self.plug_manager):
      test.execute(
          test_start=self.test_case.test_start_function,
          profile_filename=(None if profile_tempfile is None else
                            profile_tempfile.name))

    if profile_tempfile is not None:
      _merge_stats(pstats.Stats(profile_tempfile.name), profile_filepath)
      profile_tempfile.close()

    test_record_ = record_saver.result
    if test_record_.outcome_details:
      msgs = []
      for detail in test_record_.outcome_details:
        msgs.append('code: {}\ndescription: {}'.format(detail.code,
                                                       detail.description))
      failure_message = '\n'.join(msgs)
    else:
      failure_message = None
    return test_record_, failure_message

  def __next__(self):
    phase_or_test = self.iterator.send(self.last_result)
    if isinstance(phase_or_test, test_descriptor.Test):
      self.last_result, failure_message = self._handle_test(phase_or_test)
    elif not isinstance(phase_or_test, CollectionsCallable):
      raise InvalidTestError(
          'methods decorated with patch_plugs must yield Test instances or '
          'individual test phases', phase_or_test)
    else:
      self.last_result, failure_message = self._handle_phase(
          phase_descriptor.PhaseDescriptor.wrap_or_copy(phase_or_test))
    return phase_or_test, self.last_result, failure_message

  def next(self):
    phase_or_test = self.iterator.send(self.last_result)
    if isinstance(phase_or_test, test_descriptor.Test):
      self.last_result, failure_message = self._handle_test(phase_or_test)
    elif not isinstance(phase_or_test, CollectionsCallable):
      raise InvalidTestError(
          'methods decorated with patch_plugs must yield Test instances or '
          'individual test phases', phase_or_test)
    else:
      self.last_result, failure_message = self._handle_phase(
          phase_descriptor.PhaseDescriptor.wrap_or_copy(phase_or_test))
    return phase_or_test, self.last_result, failure_message


def yields_phases(func):
  """Alias for patch_plugs with no plugs patched."""
  return patch_plugs()(func)


def yields_phases_with(phase_user_defined_state=None, phase_diagnoses=None):
  """Apply patch_plugs with no plugs, but add test state modifications."""
  return patch_plugs(
      phase_user_defined_state=phase_user_defined_state,
      phase_diagnoses=phase_diagnoses)


def patch_plugs(phase_user_defined_state=None,
                phase_diagnoses=None,
                **mock_plugs):
  """Decorator for mocking plugs for a test phase.

  Usage:

    @plugs.plug(my_plug=my_plug_module.MyPlug)
    def my_phase_that_uses_my_plug(test, my_plug):
      test.logger.info('Something: %s', my_plug.do_something(10))

    @test.patch_plugs(my_plug_mock='my_plug_module.MyPlug')
    def test_my_phase(self, my_plug_mock):
      # Set up return value for the do_something method on our plug.
      my_plug_mock.do_something.return_value = 'mocked_value'

      # Yield the phase you wish to test. Typically it wouldn't be in the same
      # module like this, but this works for example purposes.
      yield my_phase_that_uses_my_plug

      # Do some assertions to make sure our plug was used correctly.
      my_plug_mock.do_something.assert_called_with(10)

    # Passing in the plug class itself also works and can be beneficial
    # when the module path is long.
    @test.patch_plugs(my_plug_mock=my_plug_module.MyPlug)
    def test_my_phase_again(self, my_plug_mock):
      # Must be a generator; a method that never yields raises
      # InvalidTestError.
      yield my_phase_that_uses_my_plug

  Args:
    phase_user_defined_state: If specified, a dictionary that will be added to
      the test_state.user_defined_state when handling phases.
    phase_diagnoses: If specified, must be a list of Diagnosis instances; these
      are added to the DiagnosesManager when handling phases.
    **mock_plugs: kwargs mapping each argument name to be passed to the test
      case to the plug type to mock, given either as the plug class or as a
      string with its full module path (the module must already be imported).
      The corresponding mock will be passed to the decorated test case as a
      keyword argument.
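
String specs are resolved by splitting off the class name and reading it from
sys.modules, so the plug's module must already be imported when the decorator
runs. A simplified sketch of that lookup, with resolve_plug_spec as a
hypothetical name (the real code additionally requires a BasePlug subclass
when a class is passed):

```python
import collections
import sys


def resolve_plug_spec(spec):
  # Hypothetical helper mirroring the lookup in patch_plugs.
  if isinstance(spec, str):
    module_name, type_name = spec.rsplit('.', 1)
    # KeyError here means the module was never imported.
    return getattr(sys.modules[module_name], type_name)
  if isinstance(spec, type):
    return spec
  raise ValueError('Invalid plug type specification %r' % (spec,))


resolved = resolve_plug_spec('collections.OrderedDict')
```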

  Returns:
    Function decorator that mocks plugs.
  """
  if phase_diagnoses:
    for diag in phase_diagnoses:
      assert isinstance(diag, diagnoses_lib.Diagnosis)

  def test_wrapper(test_func):
    plug_argspec = inspect.getfullargspec(test_func)
    num_defaults = len(plug_argspec.defaults or ())
    plug_args = set(plug_argspec.args[1:-num_defaults or None])

    # Some sanity checks to make sure the mock arg names match.
    for arg in plug_args:
      if arg not in mock_plugs:
        raise InvalidTestError(
            'Test method %s expected arg %s, but it was not provided in '
            'patch_plugs kwargs: ' % (test_func.__name__, arg), mock_plugs)
    for mock_name in mock_plugs:
      if mock_name not in plug_args:
        raise InvalidTestError(
            'patch_plugs got kwarg %s, but test method %s does not expect '
            'it.' % (mock_name, test_func.__name__), plug_args)

    # Make MagicMock instances for the plugs.
    plug_kwargs = {}  # kwargs to pass to test func.
    plug_typemap = {}  # typemap for PlugManager, maps type to instance.
    for plug_arg_name, plug_fullname in mock_plugs.items():
      if isinstance(plug_fullname, str):
        try:
          plug_module, plug_typename = plug_fullname.rsplit('.', 1)
          plug_type = getattr(sys.modules[plug_module], plug_typename)
        except Exception:
          logging.error("Invalid plug type specification %s='%s'",
                        plug_arg_name, plug_fullname)
          raise
      elif (isinstance(plug_fullname, type) and
            issubclass(plug_fullname, base_plugs.BasePlug)):
        plug_type = plug_fullname
      else:
        raise ValueError('Invalid plug type specification %s="%s"' %
                         (plug_arg_name, plug_fullname))
      if issubclass(plug_type, device_wrapping.DeviceWrappingPlug):
        # We can't strictly spec because calls to attributes are proxied to the
        # underlying device.
        plug_mock = mock.MagicMock()
      else:
        plug_mock = mock.create_autospec(
            plug_type, spec_set=True, instance=True)
      plug_typemap[plug_type] = plug_mock
      plug_kwargs[plug_arg_name] = plug_mock

    # functools.wraps is more than just aesthetic here, we need the original
    # name to match so we don't mess with unittest's TestLoader mechanism.
    @functools.wraps(test_func)
    def wrapped_test(self):
      self.assertIsInstance(
          self,
          TestCase,
          msg='Must derive from openhtf.util.test.TestCase '
          'to use yields_phases/patch_plugs.')
      plug_mocks = dict(self.plugs)
      plug_mocks.update(plug_typemap)
      for phase_or_test, result, failure_message in PhaseOrTestIterator(
          self, test_func(self, **plug_kwargs), plug_mocks,
          phase_user_defined_state, phase_diagnoses):
        logging.info('Ran %s, result: %s', phase_or_test, result)
        if failure_message:
          logging.error('Reported error:\n%s', failure_message)

    return wrapped_test

  return test_wrapper


def _assert_phase_or_test_record(func):
  """Decorator that lets an assertion accept a PhaseRecord or a TestRecord.

  This allows assertions to apply to a single phase or "any phase in the test"
  without having to handle the type check themselves.  Note that the record,
  either PhaseRecord or TestRecord, must be the first argument to the
  wrapped assertion method.

  In the case of a TestRecord, the assertion will pass if *any* PhaseRecord in
  the TestRecord passes, otherwise the *last* exception raised will be
  re-raised.
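
The TestRecord branch below is a for/else loop: stop at the first phase for
which the assertion succeeds, and only re-raise when none did. A standalone
sketch of the same control flow, using hypothetical assert_on_any and is_pass
helpers and strings in place of PhaseRecords:

```python
def assert_on_any(records, check):
  # Hypothetical stand-in for the TestRecord branch of assertion_wrapper.
  last_error = None
  for record in records:
    try:
      check(record)
      break  # One passing record is enough.
    except Exception as e:  # Broad on purpose, as in the wrapper.
      last_error = e
  else:
    # The loop finished without break: no record passed.
    if last_error is not None:
      raise last_error


def is_pass(record):
  if record != 'PASS':
    raise AssertionError('record %s did not pass' % record)


assert_on_any(['FAIL', 'PASS', 'FAIL'], is_pass)  # Passes via 'PASS'.
```

As in the wrapper, an empty sequence of records raises nothing, so the
assertion passes vacuously.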

  Args:
    func: the function to wrap.

  Returns:
    Function decorator.
  """

  @functools.wraps(func)
  def assertion_wrapper(self, phase_or_test_record, *args, **kwargs):
    if isinstance(phase_or_test_record, test_record.TestRecord):
      original_exception = None
      for phase_record in phase_or_test_record.phases:
        try:
          func(self, phase_record, *args, **kwargs)
          break
        except Exception as e:  # pylint: disable=broad-except
          original_exception = e
      else:
        if original_exception is not None:
          raise original_exception
    elif isinstance(phase_or_test_record, test_record.PhaseRecord):
      func(self, phase_or_test_record, *args, **kwargs)
    else:
      raise InvalidTestError('Expected either a PhaseRecord or TestRecord')

  return assertion_wrapper


class TestCase(unittest.TestCase):
  # Configure this via set_profile_dir().
  _profile_output_dir: Optional[pathlib.Path] = None

  def __init__(self, methodName=None):
    super(TestCase, self).__init__(methodName=methodName)
    if methodName != 'runTest':
      test_method = getattr(self, methodName)
      if inspect.isgeneratorfunction(test_method):
        raise ValueError('%s yields without @openhtf.util.test.yields_phases' %
                         methodName)

  def setUp(self):
    super(TestCase, self).setUp()
    # When a phase is yielded to a yields_phases/patch_plugs function, this
    # attribute will be set to the openhtf.core.test_state.TestState used in the
    # phase execution.
    self.last_test_state = None
    # When a test is yielded, this function is provided as the test_start
    # argument to test.execute.
    self.test_start_function = lambda: TEST_DUT_ID
    # Dictionary mapping plug class (type, not instance) to plug instance.
    # Prior to executing a phase or test, plug instances can be added here.
    # When an OpenHTF phase or test is run in this suite, any instantiated plugs
    # will be available here.
    # "Any" hint below needed because pytype doesn't like heterogeneous values.
    self.plugs = {}  # type: Any

  def auto_mock_plugs(self, *plug_types: Type[plugs.BasePlug]):
    """Specifies plugs that may be automatically mocked if needed.

    Can be called from setUp, or from inside a test case.

    Plug types that already have an entry in this instance's `plugs` attribute
    are left untouched and no mock is created for them. Mocks use autospec and
    spec_set, so this method should not be used for plugs where this isn't
    desired.
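
What autospec plus spec_set buys can be seen with the standard library alone.
The sketch below makes the same create_autospec call as this method against a
hypothetical PowerSupplyPlug class (a plain class rather than a BasePlug
subclass, so it runs without OpenHTF):

```python
from unittest import mock


class PowerSupplyPlug:
  # Hypothetical stand-in for a plug class; real plugs subclass BasePlug.

  def set_voltage(self, volts):
    raise NotImplementedError('would talk to hardware')

  def read_current(self):
    raise NotImplementedError('would talk to hardware')


plug_mock = mock.create_autospec(
    PowerSupplyPlug, spec_set=True, instance=True)

# Configured return values replace hardware access.
plug_mock.read_current.return_value = 0.25
current = plug_mock.read_current()
plug_mock.set_voltage(5.0)
plug_mock.set_voltage.assert_called_once_with(5.0)

# Autospec checks calls against the real method signatures.
try:
  plug_mock.set_voltage()  # Missing the volts argument.
  signature_checked = False
except TypeError:
  signature_checked = True

# spec_set rejects attributes the real class does not define.
try:
  plug_mock.misspelled_method = None
  spec_set_enforced = False
except AttributeError:
  spec_set_enforced = True
```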
697

698
    Args:
699
      *plug_types: Plug classes for which mocks should be used.
700
    """
701
    for plug_type in plug_types:
5✔
702
      if plug_type in self.plugs:
5✔
703
        logging.info(
×
704
            'Plug "%s" already has mock in self.plugs; skipping '
705
            'automatic mock', plug_type.__name__)
706
        continue
×
707
      self.plugs[plug_type] = mock.create_autospec(
5✔
708
          plug_type, spec_set=True, instance=True)
709

  @typing.overload
  def execute_phase_or_test(
      self,
      phase_or_test: test_descriptor.Test,
      phase_user_defined_state: None = None,  # Only supported for phases.
      phase_diagnoses: None = None,  # Only supported for phases.
  ) -> test_record.TestRecord:
    ...

  @typing.overload
  def execute_phase_or_test(
      self,
      phase_or_test: phase_descriptor.PhaseT,
      # Pytype does not correctly support heterogeneous dict values, hence Any.
      phase_user_defined_state: Optional[Any] = None,
      phase_diagnoses: Optional[Iterable[diagnoses_lib.Diagnosis]] = None,
  ) -> test_record.PhaseRecord:
    ...

  def execute_phase_or_test(self,
                            phase_or_test,
                            phase_user_defined_state=None,
                            phase_diagnoses=None):
    """Executes the provided Test or phase, returning the corresponding record.

    Args:
      phase_or_test: The Test or phase to execute.
      phase_user_defined_state: If specified, a dictionary that will be added to
        the test_state.user_defined_state when handling phases. This is only
        supported when executing a phase.
      phase_diagnoses: If specified, must be a list of Diagnosis instances;
        these are added to the DiagnosesManager when handling phases.

    Returns:
      The TestRecord (for a Test) or PhaseRecord (for a phase) of the
      execution. Pass it to the assert* methods in this class to check the
      outcome.
    """

    def phase_generator():
      phase_or_test_record = yield phase_or_test
      return phase_or_test_record

    for phase_or_test, result, failure_message in PhaseOrTestIterator(
        self, phase_generator(), self.plugs, phase_user_defined_state,
        phase_diagnoses):
      logging.info('Ran %s, result: %s', phase_or_test, result)
      if failure_message:
        logging.error('Reported error:\n%s', failure_message)
    # Pylint cannot determine that the loop above executes for exactly one
    # iteration, in any path that would lead here.
    return result  # pylint: disable=undefined-loop-variable

  ##### TestRecord Assertions #####

  def assertTestPass(self, test_rec):
    self.assertEqual(
        test_record.Outcome.PASS,
        test_rec.outcome,
        msg='\n\n{}'.format(
            text.StringFromTestRecord(
                test_rec,
                only_failures=True,
                maximum_num_measurements=_MAXIMUM_NUM_MEASUREMENTS_PER_PHASE)))

  def assertTestFail(self, test_rec):
    msg = None
    if test_rec.outcome == test_record.Outcome.ERROR:
      msg = text.StringFromOutcomeDetails(test_rec.outcome_details)
    self.assertEqual(test_record.Outcome.FAIL, test_rec.outcome, msg=msg)

  def assertTestAborted(self, test_rec):
    self.assertEqual(test_record.Outcome.ABORTED, test_rec.outcome)

  def assertTestError(self, test_rec, exc_type=None):
    self.assertEqual(test_record.Outcome.ERROR, test_rec.outcome)
    if exc_type is not None:
      self.assertPhaseError(test_rec.phases[-1], exc_type)

  def assertTestOutcomeCode(self, test_rec, code):
    """Asserts that the given code is in some OutcomeDetails in the record."""
    self.assertTrue(
        any(details.code == code for details in test_rec.outcome_details),
        'No OutcomeDetails had code %s' % code)

  @contextlib.contextmanager
  def assertTestHasPhaseRecord(self, test_rec, phase_name):
    """Yields the last PhaseRecord with the given name; fails if none exists."""
    all_phase_names = []
    expected_phase_rec = None
    for phase_rec in test_rec.phases:
      all_phase_names.append(phase_rec.name)
      if phase_rec.name == phase_name:
        expected_phase_rec = phase_rec
    self.assertIsNotNone(
        expected_phase_rec,
        msg=f'Phase "{phase_name}" not found in test phases: {all_phase_names}',
    )
    yield expected_phase_rec

  ##### PhaseRecord Assertions #####

  def assertPhaseContinue(self, phase_record):
    self.assertIs(
        phase_descriptor.PhaseResult.CONTINUE,
        phase_record.result.phase_result,
        msg='\n\n{}'.format(
            text.StringFromPhaseRecord(
                phase_record,
                only_failures=True,
                maximum_num_measurements=_MAXIMUM_NUM_MEASUREMENTS_PER_PHASE)))

  def assertPhaseFailAndContinue(self, phase_record):
    msg = None
    # raised_exception is truthy only when the phase raised.
    if phase_record.result.raised_exception:
      msg = ('The following exception was raised: '
             f'{phase_record.result.phase_result}.')
    self.assertIs(
        phase_descriptor.PhaseResult.FAIL_AND_CONTINUE,
        phase_record.result.phase_result,
        msg=msg)

  def assertPhaseFailSubtest(self, phase_record):
    msg = None
    if phase_record.result.raised_exception:
      msg = ('The following exception was raised: '
             f'{phase_record.result.phase_result}.')
    self.assertIs(
        phase_descriptor.PhaseResult.FAIL_SUBTEST,
        phase_record.result.phase_result,
        msg=msg)

  def assertPhaseRepeat(self, phase_record):
    self.assertIs(phase_descriptor.PhaseResult.REPEAT,
                  phase_record.result.phase_result)

  def assertPhaseSkip(self, phase_record):
    self.assertIs(phase_descriptor.PhaseResult.SKIP,
                  phase_record.result.phase_result)

  def assertPhaseStop(self, phase_record):
    self.assertIs(phase_descriptor.PhaseResult.STOP,
                  phase_record.result.phase_result)

  def assertPhaseError(self, phase_record, exc_type=None):
    self.assertTrue(phase_record.result.raised_exception,
                    'Phase did not raise an exception')
    if exc_type:
      self.assertIsInstance(
          phase_record.result.phase_result.exc_val, exc_type,
          'Raised exception %r is not an instance of %r' %
          (phase_record.result.phase_result.exc_val, exc_type))

  def assertPhaseTimeout(self, phase_record):
    self.assertTrue(phase_record.result.is_timeout)

  def assertPhaseOutcomePass(self, phase_record):
    self.assertIs(
        test_record.PhaseOutcome.PASS,
        phase_record.outcome,
        msg='\n\n{}'.format(
            text.StringFromPhaseRecord(
                phase_record,
                only_failures=True,
                maximum_num_measurements=_MAXIMUM_NUM_MEASUREMENTS_PER_PHASE)))

  def assertPhaseOutcomeFail(self, phase_record):
    msg = None
    if phase_record.result.raised_exception:
      msg = ('The following exception was raised: '
             f'{phase_record.result.phase_result}.')
    self.assertIs(test_record.PhaseOutcome.FAIL, phase_record.outcome, msg=msg)

  def assertPhaseOutcomeSkip(self, phase_record):
    self.assertIs(test_record.PhaseOutcome.SKIP, phase_record.outcome)

  def assertPhaseOutcomeError(self, phase_record):
    self.assertIs(test_record.PhaseOutcome.ERROR, phase_record.outcome)

  def assertPhasesOutcomeByName(self,
                                expected_outcome: test_record.PhaseOutcome,
                                test_rec: test_record.TestRecord,
                                *phase_names: Text):
    errors: List[Text] = []
    for phase_rec in filter_phases_by_names(test_rec.phases, *phase_names):
      if phase_rec.outcome is not expected_outcome:
        errors.append('Phase "{}" outcome: {}'.format(phase_rec.name,
                                                      phase_rec.outcome))
    self.assertFalse(
        errors,
        msg='Expected phases don\'t all have outcome {}: {}'.format(
            expected_outcome.name, errors))

  def assertPhasesNotRun(self, test_rec, *phase_names):
    phases = list(filter_phases_by_names(test_rec.phases, *phase_names))
    self.assertFalse(phases)

  ##### Measurement Assertions #####

  def assertNotMeasured(self, phase_or_test_record, measurement):

    def _check_phase(phase_record, strict=False):
      if strict:
        self.assertIn(measurement, phase_record.measurements)
      if measurement in phase_record.measurements:
        self.assertFalse(
            phase_record.measurements[measurement].measured_value.is_value_set,
            'Measurement %s unexpectedly set' % measurement)
        self.assertIs(measurements.Outcome.UNSET,
                      phase_record.measurements[measurement].outcome)

    if isinstance(phase_or_test_record, test_record.PhaseRecord):
      _check_phase(phase_or_test_record, True)
    else:
      # Check *all* phases (not *any* like _assert_phase_or_test_record).
      for phase_record in phase_or_test_record.phases:
        _check_phase(phase_record)

  @_assert_phase_or_test_record
  def assertMeasured(
      self, phase_record, measurement, value=mock.ANY, delta=None
  ):
    """Asserts that a measurement is set and, optionally, has a given value.

    Float values are compared with assertAlmostEqual, using `delta` as the
    tolerance if given (otherwise rounding to 7 decimal places). Any other
    value must compare equal, and passing `delta` with it raises ValueError.
    """
    self.assertTrue(
        phase_record.measurements[measurement].measured_value.is_value_set,
        'Measurement %s not set' % measurement,
    )
    if value is not mock.ANY:
      if isinstance(value, float):
        self.assertAlmostEqual(
            value,
            phase_record.measurements[measurement].measured_value.value,
            delta=delta,
            msg=(
                'Measurement %s has wrong value: expected %s, got %s,'
                ' tolerance %s'
            )
            % (
                measurement,
                value,
                phase_record.measurements[measurement].measured_value.value,
                delta,
            ),
        )
      else:
        if delta is not None:
          raise ValueError(
              'Delta is not supported when expected value is not a float.'
          )
        self.assertEqual(
            value,
            phase_record.measurements[measurement].measured_value.value,
            'Measurement %s has wrong value: expected %s, got %s'
            % (
                measurement,
                value,
                phase_record.measurements[measurement].measured_value.value,
            ),
        )

  @_assert_phase_or_test_record
  def assertMeasurementPass(self, phase_record, measurement, value=mock.ANY):
    self.assertMeasured(phase_record, measurement, value)
    self.assertIs(measurements.Outcome.PASS,
                  phase_record.measurements[measurement].outcome)

  @_assert_phase_or_test_record
  def assertMeasurementFail(self, phase_record, measurement, value=mock.ANY):
    self.assertMeasured(phase_record, measurement, value)
    self.assertIs(measurements.Outcome.FAIL,
                  phase_record.measurements[measurement].outcome)

  @_assert_phase_or_test_record
  def assertMeasurementMarginal(
      self, phase_record, measurement, value=mock.ANY
  ):
    self.assertMeasured(phase_record, measurement, value)
    self.assertTrue(phase_record.measurements[measurement].marginal)

  @_assert_phase_or_test_record
  def assertMeasurementNotMarginal(
      self, phase_record, measurement, value=mock.ANY
  ):
    self.assertMeasured(phase_record, measurement, value)
    self.assertFalse(phase_record.measurements[measurement].marginal)

  @_assert_phase_or_test_record
  def assertAttachment(self,
                       phase_record,
                       attachment_name,
                       expected_contents=mock.ANY):
    self.assertIn(attachment_name, phase_record.attachments,
                  'Attachment {} not attached.'.format(attachment_name))
    if expected_contents is not mock.ANY:
      data = phase_record.attachments[attachment_name].data
      self.assertEqual(
          expected_contents, data,
          'Attachment {} has wrong value: expected {}, got {}'.format(
              attachment_name, expected_contents, data))

  def get_diagnoses_store(self):
    self.assertIsNotNone(self.last_test_state)
    # Plain assert narrows the Optional type for static type checkers.
    assert self.last_test_state is not None
    return self.last_test_state.diagnoses_manager.store

  @classmethod
  def set_profile_dir(cls, profile_dir: pathlib.Path) -> None:
    """Sets the output directory for profiling, and enables profiling.

    WARNING: This method is provided for debugging only, and may be removed in
    a future update. The test.py module currently runs all tests in a thread,
    which cannot be profiled without this feature.
    Call this from setUpClass to enable profiling.

    Args:
      profile_dir: The directory to place the profile file in. See
        get_profile_filepath for details.
    """
    cls._profile_output_dir = profile_dir
    try:
      # Remove the file if it already exists. This has to be done in
      # setUpClass: the file should be cleared once before the test case
      # starts, then updated as the individual test* methods run.
      os.remove(cls.get_profile_filepath())
    except FileNotFoundError:
      pass

  @classmethod
  def get_profile_filepath(cls) -> Optional[pathlib.Path]:
    """Returns profile filepath if profile_output_dir is set, else None.

    The output filename is {module}.{test_case}.pstats.
    """
    if cls._profile_output_dir is not None:
      return pathlib.Path(cls._profile_output_dir,
                          f'{__name__.split(".")[-1]}.{cls.__name__}.pstats')
    return None


def get_flattened_phases(
    node_collections: Iterable[
        Union[phase_nodes.PhaseNode, phase_collections.PhaseCollectionNode]
    ],
) -> Sequence[phase_nodes.PhaseNode]:
  """Flattens nested sequences of nodes into phase descriptors."""
  phases = []
  phases_or_phase_groups = phase_collections.flatten(node_collections)
  for phase_or_phase_group in phases_or_phase_groups:
    if isinstance(phase_or_phase_group, phase_collections.PhaseCollectionNode):
      phases.extend(phase_or_phase_group.all_phases())
    elif isinstance(phase_or_phase_group, phase_nodes.PhaseNode):
      phases.append(phase_or_phase_group)
    else:
      raise TypeError('Not a phase node or a phase collection node.')
  return phases

© 2026 Coveralls, Inc