• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 19638573471

20 Nov 2025 06:18AM UTC coverage: 93.944% (+0.02%) from 93.928%
19638573471

Pull #1266

github

web-flow
Merge 824a0a27e into 7e38b378a
Pull Request #1266: Added optional arguments to check_contracts decorator to disable checks

40 of 41 new or added lines in 1 file covered. (97.56%)

23 existing lines in 1 file now uncovered.

3537 of 3765 relevant lines covered (93.94%)

17.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.68
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.

Contract checking is enabled either per function/class with the check_contracts decorator,
or for every function and class of one or more modules with check_all_contracts.

Representation invariants, preconditions, and postconditions are parsed from docstrings,
compiled, and stored. Below are some notes on how they are stored.
    - Representation invariants are stored in a class attribute __representation_invariants__
    as a list [(assertion, compiled)].
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
    [(assertion, compiled)].
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
    [(assertion, compiled, return_val_var_name)].
"""
12

13
from __future__ import annotations
20✔
14

15
import inspect
20✔
16
import logging
20✔
17
import re
20✔
18
import sys
20✔
19
import typing
20✔
20
from types import CodeType, FunctionType, ModuleType
20✔
21
from typing import (
20✔
22
    Any,
23
    Callable,
24
    Optional,
25
    TypeVar,
26
    Union,
27
    get_args,
28
    get_origin,
29
    overload,
30
)
31

32
import wrapt
20✔
33
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
20✔
34

35
# Configuration options

ENABLE_CONTRACT_CHECKING = True
"""
Set to True to enable contract checking.
"""

DEBUG_CONTRACTS = False
"""
Set to True to display debugging messages when checking contracts.
"""

RENAME_MAIN_TO_PYDEV_UMD = True
"""
Set to False to disable workaround for PyCharm's "Run File in Python Console" action.
In most cases you should not need to change this!
"""

STRICT_NUMERIC_TYPES = True
"""
Set to False to allow more specific numeric types to be accepted by more general type annotations
(for example, an int argument for a parameter annotated as float). This setting applies to
function argument and return value checks.
"""

# Extra module name decorated by check_all_contracts (when RENAME_MAIN_TO_PYDEV_UMD is True),
# for files run with PyCharm's "Run in Python Console" action.
_PYDEV_UMD_NAME = "pydev_umd"


# Default ``max_length`` of _display_value, which formats values shown in contract error
# messages (presumably the length beyond which values are truncated — confirm in _display_value).
_DEFAULT_MAX_VALUE_LENGTH = 30
# Placeholder that postconditions use to refer to the function's return value; it is replaced
# with a generated, unused variable name before the postcondition is compiled.
FUNCTION_RETURN_VALUE = "$return_value"
63

64

65
class PyTAContractError(Exception):
    """Signals that a PyTA contract was violated.

    Raised for mismatched argument or return value types, violated preconditions or
    postconditions, and violated representation invariants. The decorator wrappers in this
    module catch it and re-raise its message as an AssertionError.
    """
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Enable contract checking for every function and class in the given modules.

    By default (when called with no arguments), the current module is used.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported; names that are not in ``sys.modules`` are skipped.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked. When RENAME_MAIN_TO_PYDEV_UMD is also True, the module
            PyCharm uses for "Run in Python Console" is included as well.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names += ("__main__",)
        # PyCharm's "Run in Python Console" action runs the file under a different module name.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names += (_PYDEV_UMD_NAME,)

    found_modules = [sys.modules.get(mod_name, None) for mod_name in mod_names]

    for module in found_modules:
        if not module:
            # The module name was passed in incorrectly (or the module was never imported).
            continue
        for member_name, member in inspect.getmembers(module):
            if not (inspect.isfunction(member) or inspect.isclass(member)):
                continue
            module.__dict__[member_name] = check_contracts(member, module_names=set(mod_names))
102

103

104
@wrapt.decorator
def _enable_function_contracts(wrapped, instance, args, kwargs):
    """wrapt decorator that checks the contracts of ``wrapped`` on every call.

    A PyTAContractError raised by the checks is re-raised as an AssertionError carrying the
    same message, with the internal exception suppressed from the traceback.
    """
    # When ``instance`` is a class this is a class method call, so there is no instance.
    receiver = None if inspect.isclass(instance) else instance
    try:
        return _check_function_contracts(wrapped, receiver, args, kwargs)
    except PyTAContractError as e:
        raise AssertionError(str(e)) from None
115

116

117
# Wildcard Type Variable
Class = TypeVar("Class", bound=type)


@overload
def check_contracts(
    func_or_class: FunctionType,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> FunctionType: ...


@overload
def check_contracts(
    func_or_class: Class,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Class: ...


@overload
def check_contracts(
    func_or_class: None = None,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Callable[[Any], Any]: ...


def check_contracts(
    func_or_class: Optional[Union[Class, FunctionType]] = None,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Union[Class, FunctionType, Callable[[Any], Any]]:
    """A decorator to enable contract checking for a function or class.

    When used with a class, all methods defined within the class have contract checking enabled.
    If module_names is not None, only functions or classes defined in a module whose name is in module_names are checked.

    The decorator can be applied directly (``@check_contracts``) or called with keyword options
    first (``@check_contracts(preconditions=False)``); in the latter case it returns a decorator
    that applies those options.

    Args:
        func_or_class: the function or class to decorate, or None when the decorator is called
            with options only.
        module_names: if not None, func_or_class is only decorated when its ``__module__`` is one
            of these names; otherwise it is returned unchanged.
        argument_types: whether to check arguments against the parameter annotations.
        return_type: whether to check the return value against the return annotation.
        preconditions: whether to check the preconditions listed in the docstring.
        postconditions: whether to check the postconditions listed in the docstring.

    Returns:
        The decorated function or class; func_or_class itself when contract checking is
        disabled or it is excluded by module_names; a decorator when func_or_class is None.

    Note:
        The four checking flags are recorded only when decorating a function. When a class is
        decorated they are not passed on to add_class_invariants, so its methods do not
        receive them.

    Example:

        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """
    # Optional Arguments passed to the decorator: return a decorator that remembers them.
    if func_or_class is None:
        return wrapt.PartialCallableObjectProxy(
            check_contracts,
            module_names=module_names,
            argument_types=argument_types,
            return_type=return_type,
            preconditions=preconditions,
            postconditions=postconditions,
        )

    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class
    elif inspect.isroutine(func_or_class):
        # The options are read back by _check_function_contracts on every call.
        # NOTE(review): they are stored on the function object itself, so decorating the same
        # function object again presumably overwrites them — confirm this is intended.
        setattr(
            func_or_class,
            "check_contracts_options",
            {
                "argument_types": argument_types,
                "return_type": return_type,
                "preconditions": preconditions,
                "postconditions": postconditions,
            },
        )
        return _enable_function_contracts(func_or_class)
    elif inspect.isclass(func_or_class):
        add_class_invariants(func_or_class)
        return func_or_class
    else:
        # Default action
        return func_or_class
206

207

208
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    Every routine defined directly in ``klass`` is replaced by a contract-checking wrapper:
    staticmethods and classmethods go through check_contracts (without representation
    invariant checks), and all other methods through _instance_method_wrapper. A new
    ``__setattr__`` is installed that type-checks annotated attributes on assignment. When the
    assigning code's ``self`` is not an instance of the object's class, it also checks the
    representation invariants and undoes the assignment if they are violated.

    Does nothing if contract checking is disabled or ``klass`` already has
    ``__representation_invariants__`` in its own ``__dict__``.
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # Either checking is disabled or the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    # Namespace used to resolve the class's annotations; None when the module is unknown.
    klass_namespace = klass_mod.__dict__ if klass_mod is not None else None
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        Check representation invariants for this class when not within an instance method of the class.
        """
        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        nonlocal cls_annotations
        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_namespace)

        if name in cls_annotations:
            try:
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} for attribute {name} did not match expected type "
                    f"{_display_annotation(cls_annotations[name])}"
                ) from None

        # Remember the previous value so the assignment can be undone if it breaks an invariant.
        original_attr_value_exists = False
        original_attr_value = None
        if hasattr(self, name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)
        super(klass, self).__setattr__(name, value)

        # Locals of the frame that performed the assignment.
        frame_locals = inspect.currentframe().f_back.f_locals
        caller_self = frame_locals.get("self")
        if not isinstance(caller_self, type(self)):
            # Only validating if the attribute is not being set in a instance/class method
            # AND caller_self is an instance of self's type
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None
        elif caller_self is not self:
            # Keep track of mutations to instances that are of the same type as caller_self (and are also not `self`)
            # to enforce RIs on them only after the caller function returns.
            caller_klass = type(caller_self)
            if hasattr(caller_klass, "__mutated_instances__"):
                mutated_instances = getattr(caller_klass, "__mutated_instances__")
                if self not in mutated_instances:
                    mutated_instances.append(self)

    # Iterate over a snapshot because the loop re-assigns attributes of the class.
    for attr_name, attr_value in list(klass.__dict__.items()):
        if inspect.isroutine(attr_value):
            if isinstance(attr_value, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr_name, check_contracts(attr_value))
            else:
                setattr(klass, attr_name, _instance_method_wrapper(attr_value, klass))

    klass.__setattr__ = new_setattr
283

284

285
def _bind_call_arguments(
    wrapped: Callable[..., Any], args_with_self: tuple, kwargs: dict
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Match a call's arguments against the named parameters of ``wrapped``.

    Returns ``(passed, defaults)``. ``passed`` maps each named parameter that received an
    explicit argument to its value: positional arguments first, in order, then keyword
    arguments. ``defaults`` maps each parameter that declares a default value to that
    default. Surplus positional arguments (``*args``) and keywords that name no parameter
    (``**kwargs``) appear in neither.
    """
    code = wrapped.__code__
    positional_names = code.co_varnames[: code.co_argcount]
    # co_varnames lists positional parameters, then keyword-only ones. Positional-only
    # parameters are excluded here: a same-named keyword argument belongs to **kwargs.
    keyword_names = set(
        code.co_varnames[code.co_posonlyargcount : code.co_argcount + code.co_kwonlyargcount]
    )

    passed = dict(zip(positional_names, args_with_self))
    passed.update((name, value) for name, value in kwargs.items() if name in keyword_names)

    defaults = {}
    positional_defaults = getattr(wrapped, "__defaults__", None) or ()
    if positional_defaults:
        # __defaults__ holds the defaults of the *last* len(__defaults__) positional parameters.
        first_defaulted = len(positional_names) - len(positional_defaults)
        defaults.update(zip(positional_names[first_defaulted:], positional_defaults))
    defaults.update(getattr(wrapped, "__kwdefaults__", None) or {})
    return passed, defaults


def _check_function_contracts(wrapped, instance, args, kwargs):
    """Call ``wrapped(*args, **kwargs)`` with its contracts checked, and return the result.

    Depending on the ``check_contracts_options`` stored on ``wrapped`` (every check is enabled
    by default), this checks argument types, preconditions, the return type and
    postconditions. Pre- and postconditions are parsed from the docstring and compiled on the
    first call, then cached on the underlying function.

    Args:
        wrapped: the function being called. It may be a bound method, in which case the
            conditions are cached on its ``__func__``.
        instance: the ``self`` object for an instance method call, otherwise None.
        args: the positional arguments of the call (not including ``self``).
        kwargs: the keyword arguments of the call.

    Raises:
        PyTAContractError: if an argument or the return value does not match its annotation,
            or a precondition or postcondition is violated.
    """
    options = getattr(wrapped, "check_contracts_options", {})
    argument_types_enabled = options.get("argument_types", True)
    return_type_enabled = options.get("return_type", True)
    preconditions_enabled = options.get("preconditions", True)
    postconditions_enabled = options.get("postconditions", True)

    if instance is not None:
        klass_mod = _get_module(type(instance))
        globalns = klass_mod.__dict__ if klass_mod is not None else None
        annotations = typing.get_type_hints(wrapped, globalns=globalns)
        args_with_self = (instance,) + args
    else:
        annotations = typing.get_type_hints(wrapped)
        if inspect.ismethod(wrapped):
            # With no instance, a bound method here is a class method already bound to its
            # class: the class fills the first (``cls``) parameter, which ``args`` lacks.
            args_with_self = (wrapped.__self__,) + args
        else:
            args_with_self = args

    passed_arguments, default_arguments = _bind_call_arguments(wrapped, args_with_self, kwargs)

    if ENABLE_CONTRACT_CHECKING and argument_types_enabled:
        # Check the types of the arguments actually passed (positionally or by keyword).
        for param, arg in passed_arguments.items():
            if param not in annotations:
                continue
            try:
                _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
                if STRICT_NUMERIC_TYPES:
                    check_type_strict(param, arg, annotations[param])
                else:
                    check_type(
                        arg,
                        annotations[param],
                        collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                    )
            except (TypeError, TypeCheckError):
                additional_suggestions = _get_argument_suggestions(arg, annotations[param])

                raise PyTAContractError(
                    f"Argument value {_display_value(arg)} for {wrapped.__name__} parameter {param} "
                    f"did not match expected type {_display_annotation(annotations[param])}"
                    + (f"\n{additional_suggestions}" if additional_suggestions else "")
                )

    # Names visible to pre/postconditions: every passed argument, plus the default value of
    # each parameter that was omitted from the call.
    function_locals = dict(passed_arguments)
    for default_name, default_value in default_arguments.items():
        function_locals.setdefault(default_name, default_value)

    # Check bounded function
    target = wrapped.__func__ if hasattr(wrapped, "__self__") else wrapped

    if preconditions_enabled and not hasattr(target, "__preconditions__"):
        # Parse and compile the docstring preconditions once, caching them on the function.
        target.__preconditions__: list[tuple[str, CodeType]] = []
        for precondition in parse_assertions(wrapped):
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING and preconditions_enabled:
        _check_assertions(wrapped, function_locals)

    r = wrapped(*args, **kwargs)

    # Check return type
    if ENABLE_CONTRACT_CHECKING and return_type_enabled and "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(
                    r, return_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
                )
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"Return value {_display_value(r)} for {wrapped.__name__} did not match "
                f"expected type {_display_annotation(return_type)}"
            )

    # Check function postconditions
    if postconditions_enabled and not hasattr(target, "__postconditions__"):
        # Parse and compile once; $return_value is replaced by a name unused in the function's scope.
        target.__postconditions__: list[tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        postconditions = parse_assertions(wrapped, parse_token="Postcondition")
        for postcondition in postconditions:
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING and postconditions_enabled:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r
388

389

390
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Check that ``value`` matches ``expected_type``, treating numeric types strictly.

    Within the numeric hierarchy (bool, int, float, complex) a value is only accepted for its
    own exact type: e.g. a bool is rejected where an int is expected. The recursive work is
    done by _check_inner_type. Nothing is checked when ENABLE_CONTRACT_CHECKING is False.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``.
    """
    if ENABLE_CONTRACT_CHECKING:
        try:
            _check_inner_type(argname, value, expected_type)
        except (TypeError, TypeCheckError):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
402

403

404
def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
    """Recursively check that ``value`` matches ``expected_type`` with strict numeric types.

    Supported forms:
    - collections such as list[int], set[str], dict[str, float], tuple[int, str] and
      tuple[int, ...], checked item by item;
    - their unparameterized aliases (typing.List, typing.Set, typing.Dict, typing.Tuple),
      for which only the container type is checked;
    - unions, written either as Union[bool, int] / Optional[int] or as bool | int.

    Any other type is delegated to typeguard's check_type.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``. typeguard's TypeCheckError
            may also propagate from the delegated checks.
    """
    import types  # local import: types.UnionType (the origin of ``X | Y``) exists only on 3.10+

    inner_types = get_args(expected_type)
    outer_type = get_origin(expected_type)
    if outer_type is None:
        # A plain class: reject the implicit numeric promotions bool -> int -> float -> complex.
        if (
            (type(value) is bool and expected_type in {int, float, complex})
            or (type(value) is int and expected_type in {float, complex})
            or (type(value) is float and expected_type is complex)
        ):
            raise TypeError(
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
            )
        else:
            check_type(
                value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
    elif outer_type is typing.Union or outer_type is getattr(types, "UnionType", None):
        # Union[...] / Optional[...] and PEP 604 ``X | Y``: accept if any member matches.
        for inner_type in inner_types:
            try:
                _check_inner_type(argname, value, inner_type)
                return
            except (TypeError, TypeCheckError):
                pass
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif not inner_types and (outer_type in {list, set, dict} or expected_type is typing.Tuple):
        # Unparameterized aliases such as typing.List have an origin but no arguments, so only
        # the container type itself can be checked.
        if not isinstance(value, outer_type):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type in {list, set}:
        if isinstance(value, outer_type):
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is dict:
        if isinstance(value, dict):
            for key, item in value.items():
                _check_inner_type(argname, key, inner_types[0])
                _check_inner_type(argname, item, inner_types[1])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is tuple:
        if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
            # Variable-length tuple[T, ...]: every item must match T.
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        elif isinstance(value, tuple) and len(value) == len(inner_types):
            # Fixed-length tuple: items are matched positionally.
            for item, inner_type in zip(value, inner_types):
                _check_inner_type(argname, item, inner_type)
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    else:
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
457

458

459
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
    """Return a hint for a mismatched argument, or "" when there is nothing to suggest.

    The only hint offered is for passing a class where an instance of it (or of one of its
    base classes) was expected. Annotations that issubclass cannot handle yield no hint.
    """
    try:
        passed_class_instead_of_instance = isinstance(arg, type) and issubclass(arg, annotation)
        hint = (
            "Did you mean {cls}(...) instead of {cls}?".format(cls=arg.__name__)
            if passed_class_instead_of_instance
            else ""
        )
    except TypeError:
        hint = ""
    return hint
468

469

470
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Wrap an instance method of ``klass`` so its contracts and the class invariants are checked.

    After each call (unless the instance's ``__init__`` is still on the call stack), the
    class's attribute type annotations and representation invariants are checked for the
    instance. Invariants are also checked for other instances mutated during the call, as
    recorded in ``__mutated_instances__`` by the class's ``__setattr__``.

    Args:
        wrapped: the method function taken from ``klass.__dict__``.
        klass: the class whose annotations and representation invariants are checked.

    Returns:
        A wrapt function wrapper around ``wrapped``.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        # Create an accumulator to store the instances mutated across this function call.
        # Store and restore existing mutated instance lists in case the instance method
        # executes another instance method.
        # NOTE(review): hasattr also sees a __mutated_instances__ inherited from a base class; in
        # that case the finally block sets the attribute on instance_klass instead of deleting it.
        instance_klass = type(instance)
        mutated_instances_to_restore = None
        if hasattr(instance_klass, "__mutated_instances__"):
            mutated_instances_to_restore = getattr(instance_klass, "__mutated_instances__")
        setattr(instance_klass, "__mutated_instances__", [])

        try:
            # Argument/return types and pre/postconditions of the method itself.
            r = _check_function_contracts(wrapped, instance, args, kwargs)
            # The instance may be only partially initialized, so skip the class-level checks.
            if _instance_init_in_callstack(instance):
                return r
            # Runs regardless of ENABLE_CONTRACT_CHECKING; raises AssertionError directly.
            _check_class_type_annotations(klass, instance)
            klass_mod = _get_module(klass)
            if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
                _check_invariants(instance, klass, klass_mod.__dict__)

                # Additionally check RI violations on PyTA-decorated instances that were mutated
                # across the function call.
                mutated_instances = getattr(instance_klass, "__mutated_instances__", [])
                for mutated_instance in mutated_instances:
                    # Mutated instances may be of parent class types so the invariants to check should also be
                    # for the parent class and not the child class.
                    mutated_instance_klass = type(mutated_instance)
                    mutated_instance_klass_mod = _get_module(mutated_instance_klass)
                    _check_invariants(
                        mutated_instance,
                        mutated_instance_klass,
                        mutated_instance_klass_mod.__dict__,
                    )
        except PyTAContractError as e:
            # Present contract violations to the user as plain assertion failures.
            raise AssertionError(str(e)) from None
        else:
            return r
        finally:
            # Restore the enclosing call's accumulator (or remove ours) on every exit path.
            if mutated_instances_to_restore is None:
                delattr(instance_klass, "__mutated_instances__")
            else:
                setattr(instance_klass, "__mutated_instances__", mutated_instances_to_restore)

    return wrapper(wrapped)
515

516

517
def _instance_init_in_callstack(instance: Any) -> bool:
    """Return whether instance's init is part of the current callstack

    Walks outward from the caller's frame, looking for a frame of a function named
    ``__init__`` whose first local variable is ``self`` and whose ``self`` is ``instance``.

    Note: due to the nature of the check, externally defined __init__ functions with
    'self' defined as the first parameter may pass this check.
    """
    frame = inspect.currentframe().f_back
    while frame:
        code = frame.f_code
        if (
            # co_name is the name inspect.getframeinfo reports, without its cost of
            # reading the frame's source lines.
            code.co_name == "__init__"
            # Slice rather than index so that a frame with no local variables cannot raise.
            and code.co_varnames[:1] == ("self",)
            and frame.f_locals.get("self") is instance
        ):
            return True
        frame = frame.f_back
    return False
536

537

538
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Every attribute annotated on ``klass`` is read from ``instance`` and checked against its
    annotation, including every item of collection values.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an attribute's value does not match its annotation.
    """
    klass_mod = _get_module(klass)
    # Resolve annotations in the defining module's namespace when it is known.
    localns = klass_mod.__dict__ if klass_mod is not None else None
    cls_annotations = typing.get_type_hints(klass, localns=localns)

    for attr, annotation in cls_annotations.items():
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress the internal typeguard error, as the attribute-assignment check does.
            raise AssertionError(
                f"Value {_display_value(value)} for attribute {attr} did not match expected type "
                f"{_display_annotation(annotation)}"
            ) from None
559

560

561
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each compiled invariant in ``klass.__representation_invariants__`` is evaluated with
    ``global_scope`` as globals and ``self`` bound to ``instance``. An invariant that cannot
    be evaluated is skipped. When it fails on a name that is an attribute of the instance, a
    warning suggesting ``self.<name>`` is printed to stderr; otherwise a debug message is
    emitted.

    Raises:
        PyTAContractError: if an invariant evaluates to a falsy value.
        AssertionError: if evaluating an invariant raises AssertionError.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    # Set the recursion guard through super(), bypassing the __setattr__ defined on
    # type(instance) itself.
    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except NameError as e:
                # Get the missing name
                missing = getattr(e, "name", None)
                if missing is None:
                    # Failsafe for version 3.9
                    message = re.search(r"name '(.+?)' is not defined", str(e))
                    if message:
                        missing = message.group(1)

                # Check if missing name is an attribute
                if missing is not None and hasattr(instance, missing):
                    print(
                        f"[WARNING] Could not find variable `{missing}` when evaluating representation invariant. Did you mean `self.{missing}`?",
                        file=sys.stderr,
                    )
                else:
                    _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            except Exception:
                # Any other evaluation error skips the invariant instead of reporting a violation.
                # (Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate.)
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'{instance.__class__.__name__} representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        delattr(instance, "__pyta_currently_checking")
617

618

619
def _get_legal_return_val_var_name(var_dict: dict) -> str:
    """Return a variable name for the function's return value that is unused in ``var_dict``.

    Starting from "__function_return_value__", underscores are appended until the name is
    not a key of ``var_dict`` (the names in the function's scope). Postconditions use this
    name to refer to the return value when they are evaluated.
    """
    base_name = "__function_return_value__"
    padding = ""
    while base_name + padding in var_dict:
        padding += "_"
    return base_name + padding
631

632

633
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """Substitute the generated return-value variable name into a postcondition.

    Every occurrence of FUNCTION_RETURN_VALUE in ``assertion`` is replaced with
    ``return_val_var_name``. An assertion that does not contain FUNCTION_RETURN_VALUE is
    returned unchanged.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion
    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
644

645

646
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the given assertions are still satisfied.

    Args:
        wrapped: the function (or bound method) being called. Its compiled conditions are
            read from ``__preconditions__`` / ``__postconditions__`` on the underlying
            function, which must already be populated.
        function_locals: the call's arguments, keyed by parameter name.
        condition_type: "precondition" or "postcondition"; any other value checks nothing.
        function_return_val: the call's return value (only used for postconditions).

    Raises:
        PyTAContractError: if a condition evaluates to a falsy value.
        AssertionError: if evaluating a condition raises AssertionError.
    """
    # Check bounded function
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        # Postconditions also see the return value under their generated variable name.
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # A condition that cannot be evaluated (e.g. an undefined name) is skipped rather
            # than reported. (Narrowed from a bare except so KeyboardInterrupt propagates.)
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f" and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string}{return_val_string}"
                )
689

690

691
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.
    The token is matched case-insensitively, so e.g. "Representation invariants:" is
    accepted for parse_token "Representation Invariant".

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    Only the first docstring line that starts with parse_token is examined; if that
    line matches neither form, an empty list is returned.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # Check if obj is an astroid node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__") or ""
    lines = [line.strip() for line in docstring.split("\n")]

    token = parse_token.lower()
    first = next((i for i, line in enumerate(lines) if line.lower().startswith(token)), None)
    if first is None:
        return []

    # Classify the header with the same case-insensitive comparison used to locate it,
    # so that a header found above is not then silently rejected because of its case.
    header = lines[first].lower()

    if header.startswith(token + ":"):
        return [lines[first][len(parse_token) + 1 :].strip()]
    elif header.startswith(token + "s:"):
        assertions = []
        for line in lines[first + 1 :]:
            if line.startswith("-"):
                assertion = line[1:].strip()
                if hasattr(obj, "__qualname__"):
                    _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                assertions.append(assertion)
            elif line != "":
                # Any other non-blank text ends the group of conditions
                break
        return assertions
    else:
        return []
733

734

735
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return repr(value), abbreviated for display when it is too long.

    Unless DEBUG_CONTRACTS is True, a repr longer than max_length characters is
    shortened to its first and last (max_length - 3) // 2 characters joined by "...".

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text
    keep = (max_length - 3) // 2
    return f"{text[:keep]}...{text[-keep:]}"
749

750

751
def _display_annotation(annotation: Any) -> str:
20✔
752
    """Return a human-friendly representation of the given type annotation.
753

754
    >>> _display_annotation(int)
755
    'int'
756
    >>> _display_annotation(list[int])
757
    'list[int]'
758
    >>> from typing import List
759
    >>> _display_annotation(List[int])
760
    'typing.List[int]'
761
    """
762
    if annotation is type(None):  # Use 'None' instead of 'NoneType'
20✔
763
        return "None"
×
764
    if hasattr(annotation, "__origin__"):  # Generic type annotations
20✔
765
        return repr(annotation)
20✔
766
    elif hasattr(annotation, "__name__"):
20✔
767
        return annotation.__name__
20✔
768
    else:
UNCOV
769
        return repr(annotation)
×
770

771

772
def _get_module(obj: Any) -> ModuleType:
20✔
773
    """Return the module where obj was defined (normally obj.__module__).
774

775
    NOTE: this function defines a special case when using PyCharm and the file
776
    defining the object is "Run in Python Console". In this case, the pydevd runner
777
    renames the '__main__' module to 'pydev_umd', and so we need to access that
778
    module instead. This behaviour can be disabled by setting RENAME_MAIN_TO_PYDEV_UMD
779
    to False.
780
    """
781
    module_name = obj.__module__
20✔
782
    module = sys.modules[module_name]
20✔
783

784
    if (
20✔
785
        module_name != "__main__"
786
        or not RENAME_MAIN_TO_PYDEV_UMD
787
        or _PYDEV_UMD_NAME not in sys.modules
788
    ):
789
        return module
20✔
790

791
    # Get a function/class name to check whether it is defined in the module
UNCOV
792
    if isinstance(obj, (FunctionType, type)):
×
UNCOV
793
        name = obj.__name__
×
794
    else:
795
        # For any other type of object, be conservative and just return the module
UNCOV
796
        return module
×
797

UNCOV
798
    if name in vars(module):
×
UNCOV
799
        return module
×
800
    else:
UNCOV
801
        return sys.modules[_PYDEV_UMD_NAME]
×
802

803

804
def _debug(msg: str) -> None:
    """Log msg at DEBUG level, but only when DEBUG_CONTRACTS is True.

    When enabled, the root logger is configured via logging.basicConfig with a
    "[LEVEL] message" format before the message is emitted.
    """
    if DEBUG_CONTRACTS:
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
813

814

815
def _set_invariants(klass: type) -> None:
20✔
816
    """Retrieve and set the representation invariants of this class"""
817
    # Update representation invariants from this class' docstring and those of its superclasses.
818
    rep_invariants: list[tuple[str, CodeType]] = []
20✔
819

820
    # Iterate over all inherited classes except builtins
821
    for cls in reversed(klass.__mro__):
20✔
822
        if "__representation_invariants__" in cls.__dict__:
20✔
823
            rep_invariants.extend(cls.__representation_invariants__)
20✔
824
        elif cls.__module__ != "builtins":
20✔
825
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
20✔
826
            # Try compiling assertions
827
            for assertion in assertions:
20✔
828
                try:
20✔
829
                    compiled = compile(assertion, "<string>", "eval")
20✔
UNCOV
830
                except:
×
UNCOV
831
                    _debug(
×
832
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
833
                    )
UNCOV
834
                    continue
×
835
                rep_invariants.append((assertion, compiled))
20✔
836

837
    setattr(klass, "__representation_invariants__", rep_invariants)
20✔
838

839

840
def validate_invariants(obj: object) -> None:
    """Raise AssertionError if any representation invariant of obj is violated.

    The invariants are evaluated with the globals of the module that defines
    obj's class; a PyTAContractError from the check is re-raised as a plain
    AssertionError carrying the same message.
    """
    klass = obj.__class__
    module_globals = _get_module(klass).__dict__

    try:
        _check_invariants(obj, klass, module_globals)
    except PyTAContractError as err:
        raise AssertionError(str(err)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc