• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 8849617014

26 Apr 2024 02:11PM UTC coverage: 89.827% (-6.0%) from 95.841%
8849617014

push

github

web-flow
Fixed infinite recursion in representation invariants with method calls (#1031)

15 of 17 new or added lines in 1 file covered. (88.24%)

165 existing lines in 15 files now uncovered.

2746 of 3057 relevant lines covered (89.83%)

8.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.06
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.
5✔
2

3
Representation invariants, preconditions, and postconditions are parsed, compiled, and stored.
4
Below are some notes on how they are stored.
5
    - Representation invariants are stored in a class attribute __representation_invariants__
6
    as a list [(assertion, compiled)].
7
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
8
    [(assertion, compiled)].
9
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
10
    [(assertion, compiled, return_val_var_name)].
11
"""
12

13
import inspect
10✔
14
import logging
10✔
15
import sys
10✔
16
import typing
10✔
17
from types import CodeType, FunctionType, ModuleType
10✔
18
from typing import Any, Callable, List, Optional, Set, Tuple, TypeVar, Union, overload
10✔
19

20
import wrapt
10✔
21
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
10✔
22

23
# Configuration options

ENABLE_CONTRACT_CHECKING = True
"""
Set to True to enable contract checking.
"""

DEBUG_CONTRACTS = False
"""
Set to True to display debugging messages when checking contracts.
"""

RENAME_MAIN_TO_PYDEV_UMD = True
"""
Set to False to disable workaround for PyCharm's "Run File in Python Console" action.
In most cases you should not need to change this!
"""

# Module name under which PyCharm's "Run in Python Console" action registers the
# script being run (instead of '__main__'); used by check_all_contracts and _get_module.
_PYDEV_UMD_NAME = "pydev_umd"


# Default maximum length of a value's repr in contract error messages; longer reprs are
# shortened by _display_value unless DEBUG_CONTRACTS is True.
_DEFAULT_MAX_VALUE_LENGTH = 30
# Placeholder that postconditions use to refer to the function's return value. It is
# replaced by a legal, non-clashing variable name before the postcondition is compiled.
FUNCTION_RETURN_VALUE = "$return_value"
46

47

48
class PyTAContractError(Exception):
    """Signals that a PyTA contract does not hold.

    A contract here is a parameter/return type annotation, a precondition, a
    postcondition, or a representation invariant. The decorators in this module
    convert this exception into an AssertionError before it reaches user code.
    """
50

51

52
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Automatically check contracts for all functions and classes in the given modules.

    By default (when called with no arguments), the module being run (``__main__``) is used.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported. Names that are not in ``sys.modules`` are ignored.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names = mod_names + ("__main__",)

        # Also add _PYDEV_UMD_NAME, handling when the file is being run in PyCharm
        # with the "Run in Python Console" action.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names = mod_names + (_PYDEV_UMD_NAME,)

    # Look up every module before any decoration happens.
    modules = [sys.modules.get(module_name, None) for module_name in mod_names]

    # The same set of allowed module names applies to every member; build it once
    # instead of once per decorated function/class.
    allowed_module_names = set(mod_names)

    for module in modules:
        if not module:
            # Module name was passed in incorrectly (or the module was never imported).
            continue
        for name, value in inspect.getmembers(module):
            if inspect.isfunction(value) or inspect.isclass(value):
                module.__dict__[name] = check_contracts(
                    value, module_names=allowed_module_names
                )
85

86

87
@wrapt.decorator
def _enable_function_contracts(wrapped, instance, args, kwargs):
    """Wrap a function or method so that every call has its contracts checked.

    When ``instance`` is itself a class, the wrapped callable is a class method. There is
    then no object instance to check against the first parameter, so ``None`` is passed on
    instead. Contract violations surface as AssertionError, without the internal exception
    chained.
    """
    # inspect.isclass(None) is False, so plain functions keep instance == None.
    receiver = None if inspect.isclass(instance) else instance
    try:
        return _check_function_contracts(wrapped, receiver, args, kwargs)
    except PyTAContractError as e:
        raise AssertionError(str(e)) from None
98

99

100
# Wildcard Type Variable: stands for any class object. Used by the check_contracts
# overloads so that decorating a class is typed as returning that same class.
Class = TypeVar("Class", bound=type)
102

103

104
@overload
def check_contracts(
    func: FunctionType, module_names: Optional[Set[str]] = None
) -> FunctionType: ...


@overload
def check_contracts(func: Class, module_names: Optional[Set[str]] = None) -> Class: ...


def check_contracts(
    func_or_class: Union[Class, FunctionType], module_names: Optional[Set[str]] = None
) -> Union[Class, FunctionType]:
    """Decorate a function or class so that its contracts are checked.

    - A routine is returned wrapped, so that each call checks its type annotations,
      preconditions and postconditions.
    - A class is modified in place: all methods defined within it have contract checking
      enabled, and its representation invariants are checked.
    - Anything else is returned unchanged.

    If module_names is not None, only functions or classes defined in a module whose name is
    in module_names are decorated; everything else is returned as-is. Nothing is decorated
    when ENABLE_CONTRACT_CHECKING is False.

    Example:

        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """
    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class

    if inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)

    if inspect.isclass(func_or_class):
        # Classes are patched in place rather than wrapped.
        add_class_invariants(func_or_class)

    # Default action: hand back the (possibly patched) object itself.
    return func_or_class
150

151

152
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    This:
        - stores the representation invariants of klass and its superclasses on
          ``klass.__representation_invariants__``,
        - wraps every routine defined directly in klass: staticmethods and classmethods get
          plain contract checking, other methods additionally re-check attribute types and
          representation invariants after each call, and
        - replaces ``klass.__setattr__`` so attribute assignments are type-checked and, when
          made from outside the instance's own methods, invariant-checked.

    Does nothing if contract checking is disabled or klass was already processed.
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # This means the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        The value is first checked against the attribute's type annotation (if any).
        When the assignment is not made from within an instance method of this same
        instance, representation invariants are then checked. If they are violated, the
        attribute's previous value is restored (or the attribute is removed if it had
        none) and AssertionError is raised.
        """
        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        nonlocal cls_annotations
        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

        if name in cls_annotations:
            try:
                # Report ``name``, the attribute being assigned. The enclosing function's
                # loop variable ``attr`` is also visible in this closure but refers to an
                # unrelated class member.
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} did not match type annotation for attribute "
                    f"{name}: {_display_annotation(cls_annotations[name])}"
                ) from None

        # Remember the attribute's current value so it can be restored if this assignment
        # breaks a representation invariant. Instance attributes live in the instance
        # __dict__. ``super(klass, self)`` only searches class-level attributes after klass
        # in the MRO (e.g. inherited slots or properties), never the instance __dict__.
        try:
            instance_dict = object.__getattribute__(self, "__dict__")
        except AttributeError:
            # e.g. a __slots__-only class without a per-instance __dict__
            instance_dict = {}
        original_attr_value_exists = False
        original_attr_value = None
        if name in instance_dict:
            original_attr_value_exists = True
            original_attr_value = instance_dict[name]
        elif hasattr(super(klass, self), name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)

        super(klass, self).__setattr__(name, value)
        # The caller's local ``self`` tells whether we are inside a method of this instance.
        frame_locals = inspect.currentframe().f_back.f_locals
        if self is not frame_locals.get("self"):
            # Only validating if the attribute is not being set in an instance/class method
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None

    for attr, value in klass.__dict__.items():
        if inspect.isroutine(value):
            if isinstance(value, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr, check_contracts(value))
            else:
                setattr(klass, attr, _instance_method_wrapper(value, klass))

    klass.__setattr__ = new_setattr
217

218

219
def _check_function_contracts(wrapped, instance, args, kwargs):
    """Check the contracts of one call to ``wrapped`` and return the call's result.

    The checks run in this order:
        1. parameter type annotations (positional and keyword arguments);
        2. preconditions;
        3. the return type annotation (after calling ``wrapped(*args, **kwargs)``);
        4. postconditions.

    Preconditions and postconditions are parsed from the docstring and compiled on the
    first call. They are then cached on the underlying function as ``__preconditions__``
    and ``__postconditions__``.

    Args:
        wrapped: the function (or bound method) being called.
        instance: the object a method is called on, or None. When not None it is paired
            with the first parameter (``self``).
        args: positional arguments of the call (not including ``instance``).
        kwargs: keyword arguments of the call.

    Raises:
        PyTAContractError: if an argument or the return value does not match its type
            annotation, or a precondition/postcondition evaluates to a falsy value.
    """
    code = wrapped.__code__
    # Positional parameters (including ``self``), in declaration order.
    params = code.co_varnames[: code.co_argcount]
    # Parameters that may be passed by keyword: positional-or-keyword and keyword-only.
    # co_varnames lists positional-only, positional-or-keyword, then keyword-only names.
    keyword_params = code.co_varnames[
        code.co_posonlyargcount : code.co_argcount + code.co_kwonlyargcount
    ]
    if instance is not None:
        klass_mod = _get_module(type(instance))
        annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
    else:
        annotations = typing.get_type_hints(wrapped)
    args_with_self = args if instance is None else (instance,) + args

    # Pair every argument with its parameter: positional arguments by position, keyword
    # arguments by name. Keyword arguments must be included too. Otherwise their types go
    # unchecked, and conditions that mention them cannot be evaluated and are skipped.
    bound_arguments = list(zip(params, args_with_self))
    bound_arguments.extend(
        (name, value) for name, value in kwargs.items() if name in keyword_params
    )

    # Check function parameter types
    for param, arg in bound_arguments:
        if param in annotations:
            try:
                _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
                check_type_strict(param, arg, annotations[param])
            except (TypeError, TypeCheckError):
                additional_suggestions = _get_argument_suggestions(arg, annotations[param])

                raise PyTAContractError(
                    f"{wrapped.__name__} argument {_display_value(arg)} did not match type "
                    f"annotation for parameter {param}: {_display_annotation(annotations[param])}"
                    + (f"\n{additional_suggestions}" if additional_suggestions else "")
                )

    function_locals = dict(bound_arguments)

    # For a bound method, cache the parsed conditions on the underlying function.
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped

    # Check function preconditions
    if not hasattr(target, "__preconditions__"):
        target.__preconditions__: List[Tuple[str, CodeType]] = []
        preconditions = parse_assertions(wrapped)
        for precondition in preconditions:
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                # Catch Exception (not everything) so KeyboardInterrupt/SystemExit propagate.
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(wrapped, function_locals)

    # Check return type
    r = wrapped(*args, **kwargs)
    if "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            check_type_strict("return", r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"{wrapped.__name__}'s return value {_display_value(r)} did not match "
                f"return type annotation {_display_annotation(return_type)}"
            )

    # Check function postconditions
    if not hasattr(target, "__postconditions__"):
        target.__postconditions__: List[Tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        postconditions = parse_assertions(wrapped, parse_token="Postcondition")
        for postcondition in postconditions:
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r
308

309

310
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Ensure that ``value`` matches ``expected_type``, more strictly than typeguard alone.

    Before delegating to typeguard, two implicit numeric promotions are rejected:
        - an exact ``int`` where ``float`` is expected, and
        - a ``bool`` where ``int`` is expected.

    Does nothing when contract checking is disabled.

    Raises:
        TypeError: for either rejected promotion above.
        TypeCheckError: (from typeguard) for any other mismatch, checking all collection items.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    actual_type = type(value)
    is_rejected_promotion = (actual_type is int and expected_type is float) or (
        actual_type is bool and expected_type is int
    )
    if is_rejected_promotion:
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")

    check_type(value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS)
325

326

327
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
10✔
328
    """Returns potential suggestions for the given arg and its annotation"""
329
    try:
10✔
330
        if isinstance(arg, type) and issubclass(arg, annotation):
10✔
331
            return "Did you mean {cls}(...) instead of {cls}?".format(cls=arg.__name__)
10✔
332
    except TypeError:
10✔
333
        pass
10✔
334

335
    return ""
10✔
336

337

338
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Return ``wrapped`` decorated so that each call checks its contracts.

    After the method returns, klass's attribute type annotations and representation
    invariants are also checked for the instance. Those class-level checks are skipped
    while the instance's ``__init__`` is still on the call stack, since the object may not
    be fully initialized yet. Contract violations surface as AssertionError.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        try:
            result = _check_function_contracts(wrapped, instance, args, kwargs)
            if not _instance_init_in_callstack(instance):
                _check_class_type_annotations(klass, instance)
                klass_mod = _get_module(klass)
                if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
                    _check_invariants(instance, klass, klass_mod.__dict__)
        except PyTAContractError as e:
            raise AssertionError(str(e)) from None
        return result

    return wrapper(wrapped)
355

356

357
def _instance_init_in_callstack(instance: Any) -> bool:
10✔
358
    """Return whether instance's init is part of the current callstack
359

360
    Note: due to the nature of the check, externally defined __init__ functions with
361
    'self' defined as the first parameter may pass this check.
362
    """
363
    frame = inspect.currentframe().f_back
10✔
364
    while frame:
10✔
365
        frame_context_name = inspect.getframeinfo(frame).function
10✔
366
        frame_context_self = frame.f_locals.get("self")
10✔
367
        frame_context_vars = frame.f_code.co_varnames
10✔
368
        if (
10✔
369
            frame_context_name == "__init__"
370
            and frame_context_self is instance
371
            and frame_context_vars[0] == "self"
372
        ):
373
            return True
10✔
374
        frame = frame.f_back
10✔
375
    return False
10✔
376

377

378
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Every attribute annotated on klass, as returned by ``typing.get_type_hints``, is read
    from instance. Its value is checked against the annotation, including all items of
    collections.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an attribute's value does not match its annotation.
    """
    klass_mod = _get_module(klass)
    cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

    for attr, annotation in cls_annotations.items():
        # NOTE(review): raises AttributeError if an annotated attribute was never set.
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress typeguard's exception context so only the PyTA message is shown,
            # consistent with the module's other AssertionError re-raises.
            raise AssertionError(
                f"{_display_value(value)} did not match type annotation for attribute {attr}: "
                f"{_display_annotation(annotation)}"
            ) from None
399

400

401
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each compiled invariant stored on ``klass.__representation_invariants__`` is evaluated
    with ``global_scope`` as globals and ``self`` bound to ``instance``.

    While checking, a temporary ``__pyta_currently_checking`` attribute marks the instance.
    A re-entrant call (e.g. an invariant that calls a checked method) then returns
    immediately instead of recursing forever.

    Raises:
        PyTAContractError: if an invariant evaluates to a falsy value.
        AssertionError: if evaluating an invariant raises AssertionError.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    # Set the marker via the class after type(instance) in the MRO. This bypasses
    # type(instance)'s own (possibly contract-checking) __setattr__.
    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except Exception:
                # Best effort: an invariant that cannot be evaluated is reported (in debug
                # mode) and skipped. Catching Exception rather than everything lets
                # KeyboardInterrupt and SystemExit propagate.
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'"{instance.__class__.__name__}" representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        # Always remove the marker, including when an invariant is violated.
        delattr(instance, "__pyta_currently_checking")
440

441

442
def _get_legal_return_val_var_name(var_dict: dict) -> str:
10✔
443
    """
444
    Add '_' to the end of __function_return_value__ until a variable name that has not been used for any other
445
    variable in the function's scope is created. This is used to refer to the function's return value when evaluating
446
    postconditions.
447
    """
448
    legal_var_name = "__function_return_value__"
10✔
449

450
    while legal_var_name in var_dict:
10✔
UNCOV
451
        legal_var_name += "_"
×
452

453
    return legal_var_name
10✔
454

455

456
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
10✔
457
    """
458
    Replace FUNCTION_RETURN_VALUE in the assertion with the legal python variable name generated and return the new
459
    assertion. If FUNCTION_RETURN_VALUE does not appear in assertion, then simply return the original assertion.
460

461
    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
462
    """
463

464
    if FUNCTION_RETURN_VALUE in assertion:
10✔
465
        return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
10✔
466
    return assertion
10✔
467

468

469
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the given assertions are still satisfied.

    Evaluates the cached preconditions (``condition_type="precondition"``) or postconditions
    (``condition_type="postcondition"``) of wrapped. The evaluation namespace is built from
    the function's globals, then function_locals (argument values by parameter name). For
    postconditions, the return value is also bound to its generated variable name.

    A condition that raises an exception other than AssertionError is skipped and only
    reported via a debug message.

    Raises:
        PyTAContractError: if a condition evaluates to a falsy value.
        AssertionError: if evaluating a condition raises AssertionError.
    """
    # For a bound method, the conditions are cached on the underlying function.
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: skip conditions that cannot be evaluated. Catching Exception
            # rather than everything lets KeyboardInterrupt and SystemExit propagate.
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f"and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string} {return_val_string}"
                )
512

513

514
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> List[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    Only the first line in one of these forms is used. Other lines that merely begin with
    the token (e.g. prose such as "Preconditions are listed below.") are ignored rather
    than ending the search.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # Check if obj is an astroid node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__") or ""
    lines = [line.strip() for line in docstring.split("\n")]

    single_prefix = parse_token + ":"
    group_prefix = parse_token + "s:"

    for i, line in enumerate(lines):
        if line.startswith(single_prefix):
            return [line[len(single_prefix) :].strip()]
        if line.startswith(group_prefix):
            assertions = []
            for item in lines[i + 1 :]:
                if item.startswith("-"):
                    assertion = item[1:].strip()
                    if hasattr(obj, "__qualname__"):
                        _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                    assertions.append(assertion)
                elif item != "":
                    # Any other non-blank text ends the group.
                    break
            return assertions

    return []
556

557

558
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return a human-friendly representation of the given value.

    This is ``repr(value)``. When DEBUG_CONTRACTS is False and the repr is longer than
    max_length, its middle is elided with "...". (max_length - 3) // 2 characters are
    kept from each end.

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text
    keep = (max_length - 3) // 2
    return f"{text[:keep]}...{text[-keep:]}"
572

573

574
def _display_annotation(annotation: Any) -> str:
10✔
575
    """Return a human-friendly representation of the given type annotation.
576

577
    >>> _display_annotation(int)
578
    'int'
579
    >>> _display_annotation(list[int])
580
    'list[int]'
581
    >>> from typing import List
582
    >>> _display_annotation(List[int])
583
    'typing.List[int]'
584
    """
585
    if annotation is type(None):  # Use 'None' instead of 'NoneType'
10✔
UNCOV
586
        return "None"
×
587
    if hasattr(annotation, "__origin__"):  # Generic type annotations
10✔
588
        return repr(annotation)
10✔
589
    elif hasattr(annotation, "__name__"):
10✔
590
        return annotation.__name__
10✔
591
    else:
UNCOV
592
        return repr(annotation)
×
593

594

595
def _get_module(obj: Any) -> ModuleType:
10✔
596
    """Return the module where obj was defined (normally obj.__module__).
597

598
    NOTE: this function defines a special case when using PyCharm and the file
599
    defining the object is "Run in Python Console". In this case, the pydevd runner
600
    renames the '__main__' module to 'pydev_umd', and so we need to access that
601
    module instead. This behaviour can be disabled by setting RENAME_MAIN_TO_PYDEV_UMD
602
    to False.
603
    """
604
    module_name = obj.__module__
10✔
605
    module = sys.modules[module_name]
10✔
606

607
    if (
10✔
608
        module_name != "__main__"
609
        or not RENAME_MAIN_TO_PYDEV_UMD
610
        or _PYDEV_UMD_NAME not in sys.modules
611
    ):
612
        return module
10✔
613

614
    # Get a function/class name to check whether it is defined in the module
UNCOV
615
    if isinstance(obj, (FunctionType, type)):
×
UNCOV
616
        name = obj.__name__
×
617
    else:
618
        # For any other type of object, be conservative and just return the module
UNCOV
619
        return module
×
620

UNCOV
621
    if name in vars(module):
×
UNCOV
622
        return module
×
623
    else:
UNCOV
624
        return sys.modules[_PYDEV_UMD_NAME]
×
625

626

627
def _debug(msg: str) -> None:
10✔
628
    """Display a debugging message.
629

630
    Do nothing if DEBUG_CONTRACTS is False.
631
    """
632
    if not DEBUG_CONTRACTS:
10✔
633
        return
10✔
634
    logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
10✔
635
    logging.debug(msg)
10✔
636

637

638
def _set_invariants(klass: type) -> None:
    """Retrieve and set the representation invariants of this class.

    Walks klass's MRO from the most basic class up to klass itself, so inherited invariants
    come first.
        - A class that already stores ``__representation_invariants__`` in its own
          ``__dict__`` contributes that stored list as-is.
        - Any other non-builtin class has its docstring parsed for
          "Representation Invariant(s)", and each invariant is compiled.
    Invariants that fail to compile are skipped with a debug message. The result is stored
    on ``klass.__representation_invariants__`` as a list of (source, compiled code) pairs.
    """
    rep_invariants: List[Tuple[str, CodeType]] = []

    # Iterate over all inherited classes except builtins
    for cls in reversed(klass.__mro__):
        if "__representation_invariants__" in cls.__dict__:
            rep_invariants.extend(cls.__representation_invariants__)
        elif cls.__module__ != "builtins":
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
            # Try compiling assertions
            for assertion in assertions:
                try:
                    compiled = compile(assertion, "<string>", "eval")
                except Exception:
                    # Catch Exception (not everything) so KeyboardInterrupt/SystemExit propagate.
                    _debug(
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
                    )
                    continue
                rep_invariants.append((assertion, compiled))

    setattr(klass, "__representation_invariants__", rep_invariants)
661

662

663
def validate_invariants(obj: object) -> None:
    """Check that the representation invariants of obj are satisfied.

    Raises:
        AssertionError: if any representation invariant stored on obj's class is violated.
    """
    owner = obj.__class__
    owner_globals = _get_module(owner).__dict__
    try:
        _check_invariants(obj, owner, owner_globals)
    except PyTAContractError as e:
        raise AssertionError(str(e)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc