• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 19686422763

25 Nov 2025 10:47PM UTC coverage: 93.933% (+0.005%) from 93.928%
19686422763

Pull #1265

github

web-flow
Merge e0a8ef6a2 into 38734420f
Pull Request #1265: Extending AccumulationTable context manager to support multiple loop evaluations

3530 of 3758 relevant lines covered (93.93%)

17.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.56
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.
2

3
Representation invariants, preconditions, and postconditions are parsed, compiled, and stored.
4
Below are some notes on how they are stored.
5
    - Representation invariants are stored in a class attribute __representation_invariants__
6
    as a list [(assertion, compiled)].
7
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
8
    [(assertion, compiled)].
9
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
10
    [(assertion, compiled, return_val_var_name)].
11
"""
12

13
from __future__ import annotations

import inspect
import logging
import re
import sys
import types
import typing
from types import CodeType, FunctionType, ModuleType
from typing import (
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
    get_args,
    get_origin,
    overload,
)

import wrapt
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
34

35
# Configuration options: module-level flags that users may change after importing this module.

ENABLE_CONTRACT_CHECKING = True
"""
Set to True to enable contract checking.
"""

DEBUG_CONTRACTS = False
"""
Set to True to display debugging messages when checking contracts.
"""

RENAME_MAIN_TO_PYDEV_UMD = True
"""
Set to False to disable workaround for PyCharm's "Run File in Python Console" action.
In most cases you should not need to change this!
"""

STRICT_NUMERIC_TYPES = True
"""
Set to False to allow more specific numeric types to be accepted by more general type annotations.
"""

# --- Internal constants ---

# Extra module name that check_all_contracts decorates (when RENAME_MAIN_TO_PYDEV_UMD is set)
# to support PyCharm's "Run File in Python Console" action.
_PYDEV_UMD_NAME = "pydev_umd"


# Default ``max_length`` argument of ``_display_value``, which formats values for error messages.
_DEFAULT_MAX_VALUE_LENGTH = 30
# Placeholder that postconditions use for the function's return value; it is replaced by a
# legal Python variable name before the postcondition is compiled.
FUNCTION_RETURN_VALUE = "$return_value"
63

64

65
class PyTAContractError(Exception):
    """Error raised when a PyTA contract assertion is violated.

    It is raised internally for mismatched argument/return/attribute types, falsy
    preconditions or postconditions, and violated representation invariants. The wrappers
    in this module catch it and re-raise its message as an AssertionError.
    """
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Automatically check contracts for all functions and classes in the given modules.

    By default (when called with no arguments), the module being run (``__main__``) is used.

    Every function and class found in a module is replaced in that module's namespace by the
    result of ``check_contracts``. Only members defined in one of the given modules are decorated.
    Each module object is processed at most once, even if several of the names refer to it
    (e.g. ``check_all_contracts(__name__)`` called from a script, where ``__name__`` is
    ``"__main__"``).

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported. Names that are not in ``sys.modules`` are ignored.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names = mod_names + ("__main__",)

        # Also add _PYDEV_UMD_NAME, handling when the file is being run in PyCharm
        # with the "Run in Python Console" action.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names = mod_names + (_PYDEV_UMD_NAME,)

    # Built once and shared: check_contracts only reads it to filter members by their module.
    module_name_set = set(mod_names)
    processed_module_ids = set()

    for module_name in mod_names:
        module = sys.modules.get(module_name, None)
        if not module:
            # Module name was passed in incorrectly.
            continue
        if id(module) in processed_module_ids:
            # Decorating the same module again would wrap every function a second time,
            # so each contract would then be checked twice per call.
            continue
        processed_module_ids.add(id(module))

        for name, value in inspect.getmembers(module):
            if inspect.isfunction(value) or inspect.isclass(value):
                module.__dict__[name] = check_contracts(value, module_names=module_name_set)
102

103

104
# Wildcard type variable standing for "any class object". The check_contracts overloads use it
# so that decorating a class is typed as returning that same class.
Class = TypeVar("Class", bound=type)
106

107

108
# Typing-only signatures for check_contracts. They mirror the implementation: the decorated
# object is the only positional parameter (named func_or_class), and every option is keyword-only.


@overload
def check_contracts(
    func_or_class: FunctionType,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> FunctionType: ...


@overload
def check_contracts(
    func_or_class: Class,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Class: ...


@overload
def check_contracts(
    func_or_class: None = None,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Callable[[Any], Any]: ...
128

129

130
def check_contracts(
    func_or_class: Optional[Union[Class, FunctionType]] = None,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Union[Class, FunctionType]:
    """A decorator to enable contract checking for a function or class.

    When used with a class, all methods defined within the class have contract checking enabled.
    If module_names is not None, only functions or classes defined in a module whose name is in module_names are checked.

    When used with functions, `check_contracts` accepts four optional boolean keyword arguments to selectively disable checks when set to `False`:

    - `argument_types`: check parameter type annotations
    - `return_type`: check the return type annotation
    - `preconditions`: check preconditions
    - `postconditions`: check postconditions

    By default, all four checks are enabled. These arguments only affect functions, and are ignored when `check_contracts` is applied to a class.

    When called without a function or class (e.g. ``@check_contracts(argument_types=False)``),
    it returns a decorator that applies `check_contracts` with the given keyword arguments.
    When ENABLE_CONTRACT_CHECKING is False, the function or class is returned unchanged.

    Example:
        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """

    @wrapt.decorator
    def _enable_function_contracts(wrapped, instance, args, kwargs):
        """A decorator that enables checking contracts for a function."""
        # wrapt reports the class itself as ``instance`` for class methods; contract checking
        # treats those calls as having no instance.
        bound_instance = None if inspect.isclass(instance) else instance
        try:
            return _check_function_contracts(
                wrapped,
                bound_instance,
                args,
                kwargs,
                argument_types_enabled=argument_types,
                return_type_enabled=return_type,
                preconditions_enabled=preconditions,
                postconditions_enabled=postconditions,
            )
        except PyTAContractError as e:
            # Users see a plain AssertionError without the internal exception chain.
            raise AssertionError(str(e)) from None

    # Optional Arguments passed to the decorator
    if func_or_class is None:
        return wrapt.PartialCallableObjectProxy(
            check_contracts,
            module_names=module_names,
            argument_types=argument_types,
            return_type=return_type,
            preconditions=preconditions,
            postconditions=postconditions,
        )

    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class
    elif inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)
    elif inspect.isclass(func_or_class):
        add_class_invariants(func_or_class)
        return func_or_class
    else:
        # Default action
        return func_or_class
222

223

224
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    The class is changed in place:
      - ``_set_invariants`` is called on it to set up its representation invariants;
      - every routine defined directly in the class body is replaced. Static and class methods
        are wrapped with ``check_contracts`` (no representation invariant check). All other
        routines are wrapped with ``_instance_method_wrapper``;
      - ``__setattr__`` is replaced so that assignments to annotated attributes are type-checked.
        Assignments made from outside the instance's own methods are also checked against the
        representation invariants, and rolled back if they violate them.

    Does nothing if contract checking is disabled, or if the class already has its own
    ``__representation_invariants__`` attribute (i.e. it was already decorated).
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # This means the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        Check representation invariants for this class when not within an instance method of the class.

        Raises:
            AssertionError: if ``value`` does not match the attribute's annotation, or if the
                assignment violates a representation invariant (the assignment is undone first).
        """
        nonlocal cls_annotations

        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(
                klass, localns=klass_mod.__dict__ if klass_mod is not None else None
            )

        if name in cls_annotations:
            try:
                # Report the attribute actually being assigned (``name``).
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} for attribute {name} did not match expected type "
                    f"{_display_annotation(cls_annotations[name])}"
                ) from None

        # Remember the previous value so that a failed invariant check can roll back the assignment.
        original_attr_value_exists = False
        original_attr_value = None
        if hasattr(self, name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)
        super(klass, self).__setattr__(name, value)

        # The frame that performed the assignment; its ``self`` (if any) says whose code did it.
        frame_locals = inspect.currentframe().f_back.f_locals
        caller_self = frame_locals.get("self")
        if not isinstance(caller_self, type(self)):
            # Only validating if the attribute is not being set in a instance/class method
            # AND caller_self is an instance of self's type
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None
        elif caller_self is not self:
            # Keep track of mutations to instances that are of the same type as caller_self (and are also not `self`)
            # to enforce RIs on them only after the caller function returns.
            caller_klass = type(caller_self)
            if hasattr(caller_klass, "__mutated_instances__"):
                mutated_instances = getattr(caller_klass, "__mutated_instances__")
                if self not in mutated_instances:
                    mutated_instances.append(self)

    # Snapshot the items so the loop does not iterate the live class namespace while rebinding it.
    for attr_name, attr_value in list(klass.__dict__.items()):
        if inspect.isroutine(attr_value):
            if isinstance(attr_value, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr_name, check_contracts(attr_value))
            else:
                setattr(klass, attr_name, _instance_method_wrapper(attr_value, klass))

    klass.__setattr__ = new_setattr
299

300

301
def _check_function_contracts(
    wrapped,
    instance,
    args,
    kwargs,
    argument_types_enabled: bool = True,
    return_type_enabled: bool = True,
    preconditions_enabled: bool = True,
    postconditions_enabled: bool = True,
):
    """Call ``wrapped(*args, **kwargs)`` and check its contracts around the call.

    The following checks can each be switched off with the ``*_enabled`` flags:
      - argument types against the parameter annotations (before the call);
      - preconditions parsed from the docstring (before the call);
      - the return type against the return annotation (after the call);
      - postconditions parsed from the docstring (after the call).

    Arguments are checked whether they were passed positionally or by keyword. Pre- and
    postconditions are parsed and compiled on first use and cached on the underlying function
    as ``__preconditions__`` / ``__postconditions__``.

    Args:
        wrapped: the function or bound method being called.
        instance: the receiver of an instance-method call, or None.
        args: positional arguments of the call, excluding the receiver.
        kwargs: keyword arguments of the call.

    Returns:
        The value returned by ``wrapped``.

    Raises:
        PyTAContractError: if an argument or the return value does not match its annotation, or
            a pre/postcondition evaluates to a falsy value.
    """
    if instance is not None:
        klass_mod = _get_module(type(instance))
        # Resolve annotations in the namespace of the module defining the instance's class.
        if klass_mod is not None:
            annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
        else:
            annotations = typing.get_type_hints(wrapped)
    else:
        annotations = typing.get_type_hints(wrapped)

    # Parameter name -> argument value for this call (also the locals for pre/postconditions).
    function_locals = _bind_call_arguments(wrapped, instance, args, kwargs)

    if argument_types_enabled:
        # Check function parameter types
        for param, arg in function_locals.items():
            if param not in annotations:
                continue
            try:
                _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
                if STRICT_NUMERIC_TYPES:
                    check_type_strict(param, arg, annotations[param])
                else:
                    check_type(arg, annotations[param])
            except (TypeError, TypeCheckError):
                additional_suggestions = _get_argument_suggestions(arg, annotations[param])

                raise PyTAContractError(
                    f"Argument value {_display_value(arg)} for {wrapped.__name__} parameter {param} "
                    f"did not match expected type {_display_annotation(annotations[param])}"
                    + (f"\n{additional_suggestions}" if additional_suggestions else "")
                )

    # Contracts are cached on the underlying function, not on a (transient) bound method.
    target = wrapped.__func__ if hasattr(wrapped, "__self__") else wrapped

    # Check function preconditions
    if not hasattr(target, "__preconditions__") and preconditions_enabled:
        target.__preconditions__: list[tuple[str, CodeType]] = []
        preconditions = parse_assertions(wrapped)
        for precondition in preconditions:
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                # Best effort: an unparsable precondition is skipped rather than failing the call.
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING and preconditions_enabled:
        _check_assertions(wrapped, function_locals)

    # Check return type
    r = wrapped(*args, **kwargs)
    if return_type_enabled and "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"Return value {_display_value(r)} for {wrapped.__name__} did not match "
                f"expected type {_display_annotation(return_type)}"
            )

    # Check function postconditions
    if postconditions_enabled and not hasattr(target, "__postconditions__"):
        target.__postconditions__: list[tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        postconditions = parse_assertions(wrapped, parse_token="Postcondition")
        for postcondition in postconditions:
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                # Best effort: an unparsable postcondition is skipped rather than failing the call.
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING and postconditions_enabled:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r


def _bind_call_arguments(wrapped, instance, args, kwargs) -> dict[str, Any]:
    """Return a mapping from ``wrapped``'s parameter names to the values passed in this call.

    The receiver is not part of ``args``. For instance methods it is ``instance``. For class
    methods (``instance`` is None and ``wrapped`` is a method bound to the class) it is
    ``wrapped.__self__``. Positional values are matched to positional parameters in order.
    Values passed by keyword are then added for parameters that accept keywords. Parameters
    left at their default, and ``*args``/``**kwargs`` parameters, are not included.
    """
    code = wrapped.__code__
    if instance is not None:
        receiver = (instance,)
    elif inspect.ismethod(wrapped):
        # Bound class method: its ``cls`` is the first parameter but is absent from ``args``.
        receiver = (wrapped.__self__,)
    else:
        receiver = ()

    positional_names = code.co_varnames[: code.co_argcount]
    bound = dict(zip(positional_names, receiver + tuple(args)))

    # co_varnames lists positional parameters, then keyword-only ones. Positional-only
    # parameters (the first co_posonlyargcount names) cannot be passed by keyword.
    keyword_names = code.co_varnames[
        code.co_posonlyargcount : code.co_argcount + code.co_kwonlyargcount
    ]
    for param_name in keyword_names:
        if param_name in kwargs and param_name not in bound:
            bound[param_name] = kwargs[param_name]
    return bound
406

407

408
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Check ``value`` against ``expected_type`` without numeric widening.

    Within the numeric tower (bool, int, float, complex) the value's type must match exactly.
    For example, a bool is rejected where an int is expected, and an int is rejected where a
    float is expected. The recursive work, including collections and unions, is done by
    ``_check_inner_type``. This function does nothing when ENABLE_CONTRACT_CHECKING is False.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``.
    """
    if ENABLE_CONTRACT_CHECKING:
        try:
            _check_inner_type(argname, value, expected_type)
        except (TypeError, TypeCheckError):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
420

421

422
# Origins that denote a union: ``Union[...]``/``Optional[...]`` report ``typing.Union``, while
# PEP 604 ``X | Y`` reports ``types.UnionType`` on Python 3.10+.
_UNION_ORIGINS = (typing.Union, getattr(types, "UnionType", typing.Union))


def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
    """Recursively check that ``value`` matches ``expected_type`` for strict type validation.

    Supported forms:
      - plain types, rejecting bool for int/float/complex, int for float/complex, and float
        for complex;
      - unions, written as ``Union[...]``, ``Optional[...]`` or ``X | Y``;
      - ``list[...]``, ``set[...]``, ``dict[..., ...]`` and ``tuple[...]`` (including
        ``tuple[T, ...]``), checked element by element;
      - unparameterized ``typing.List``/``Set``/``Dict``/``Tuple``, which only check the container type.

    Any other annotation is delegated to typeguard's ``check_type``.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``. ``TypeCheckError`` may also
            propagate from annotations delegated to ``check_type``.
    """
    inner_types = get_args(expected_type)
    outer_type = get_origin(expected_type)
    if outer_type is None:
        if (
            (type(value) is bool and expected_type in {int, float, complex})
            or (type(value) is int and expected_type in {float, complex})
            or (type(value) is float and expected_type is complex)
        ):
            raise TypeError(
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
            )
        else:
            check_type(
                value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
    elif outer_type in _UNION_ORIGINS:
        for inner_type in inner_types:
            try:
                _check_inner_type(argname, value, inner_type)
                return
            except (TypeError, TypeCheckError):
                pass
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type in (list, set, dict, tuple) and not inner_types:
        # Unparameterized alias such as ``typing.List``: there are no element types to check.
        if not isinstance(value, outer_type):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type in {list, set}:
        if isinstance(value, outer_type):
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is dict:
        if isinstance(value, dict):
            for key, item in value.items():
                _check_inner_type(argname, key, inner_types[0])
                _check_inner_type(argname, item, inner_types[1])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is tuple:
        if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
            # Homogeneous variable-length tuple: tuple[T, ...]
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        elif isinstance(value, tuple) and len(value) == len(inner_types):
            for item, inner_type in zip(value, inner_types):
                _check_inner_type(argname, item, inner_type)
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    else:
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
475

476

477
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
    """Return a hint for a common mistake behind an argument type error, or "" if none applies.

    The only mistake recognized is passing a class where an instance of that class (or of one
    of its superclasses) was expected. In that case the hint suggests calling the class.
    Annotations that ``issubclass`` cannot handle (e.g. parameterized generics) produce no hint.
    """
    try:
        passed_class_instead_of_instance = isinstance(arg, type) and issubclass(arg, annotation)
    except TypeError:
        passed_class_instead_of_instance = False

    if not passed_class_instead_of_instance:
        return ""
    return "Did you mean {cls}(...) instead of {cls}?".format(cls=arg.__name__)
486

487

488
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Return instance method ``wrapped`` of ``klass``, wrapped with contract checking.

    Each call first checks the method's own contracts with ``_check_function_contracts``.
    Unless the instance's ``__init__`` is still on the call stack, it then checks:
      - the class's attribute type annotations on the instance;
      - the class's representation invariants on the instance;
      - the representation invariants of other instances recorded (by the patched
        ``__setattr__``) as mutated during the call.

    Raises (from the returned wrapper):
        AssertionError: on any contract, annotation or invariant violation.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        # Create an accumulator to store the instances mutated across this function call.
        # Store and restore existing mutated instance lists in case the instance method
        # executes another instance method.
        instance_klass = type(instance)
        # Only look at the class's own namespace. hasattr() would also find a list inherited
        # from a base class whose method is currently running. The finally block would then
        # pin that list onto this class permanently instead of removing the attribute.
        mutated_instances_to_restore = vars(instance_klass).get("__mutated_instances__")
        setattr(instance_klass, "__mutated_instances__", [])

        try:
            r = _check_function_contracts(wrapped, instance, args, kwargs)
            if _instance_init_in_callstack(instance):
                # The instance is still being initialized: skip annotation/invariant checks.
                return r
            _check_class_type_annotations(klass, instance)
            klass_mod = _get_module(klass)
            if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
                _check_invariants(instance, klass, klass_mod.__dict__)

                # Additionally check RI violations on PyTA-decorated instances that were mutated
                # across the function call.
                mutated_instances = getattr(instance_klass, "__mutated_instances__", [])
                for mutated_instance in mutated_instances:
                    # Mutated instances may be of parent class types so the invariants to check should also be
                    # for the parent class and not the child class.
                    mutated_instance_klass = type(mutated_instance)
                    mutated_instance_klass_mod = _get_module(mutated_instance_klass)
                    if mutated_instance_klass_mod is None:
                        # Same rule as above: without a module namespace, invariants are not checked.
                        continue
                    _check_invariants(
                        mutated_instance,
                        mutated_instance_klass,
                        mutated_instance_klass_mod.__dict__,
                    )
        except PyTAContractError as e:
            raise AssertionError(str(e)) from None
        else:
            return r
        finally:
            if mutated_instances_to_restore is None:
                delattr(instance_klass, "__mutated_instances__")
            else:
                setattr(instance_klass, "__mutated_instances__", mutated_instances_to_restore)

    return wrapper(wrapped)
533

534

535
def _instance_init_in_callstack(instance: Any) -> bool:
    """Return whether instance's __init__ is part of the current callstack.

    A caller frame matches when its code object is named ``__init__``, its first variable
    name is ``self``, and that ``self`` is ``instance``.

    Note: due to the nature of the check, externally defined __init__ functions with
    'self' defined as the first parameter may pass this check.
    """
    frame = inspect.currentframe().f_back
    while frame:
        code = frame.f_code
        # Test the code object first; f_locals is only read for frames already named __init__.
        # ``code.co_name`` is the same value inspect.getframeinfo(frame).function reports,
        # without getframeinfo reading source lines for every frame on the stack.
        # Slicing co_varnames also avoids an IndexError when it is empty.
        if (
            code.co_name == "__init__"
            and code.co_varnames[:1] == ("self",)
            and frame.f_locals.get("self") is instance
        ):
            return True
        frame = frame.f_back
    return False
554

555

556
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Every attribute reported by ``typing.get_type_hints(klass)`` is read from ``instance`` and
    checked with typeguard's ``check_type``, which checks every item of collections.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an attribute's value does not match its annotation.
        AttributeError: if an annotated attribute has not been set on ``instance``
            (NOTE(review): ``getattr`` is called without a default; confirm this is intended).
    """
    klass_mod = _get_module(klass)
    cls_annotations = typing.get_type_hints(
        klass, localns=klass_mod.__dict__ if klass_mod is not None else None
    )

    for attr, annotation in cls_annotations.items():
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress the typeguard exception context so only the PyTA message is shown.
            raise AssertionError(
                f"Value {_display_value(value)} for attribute {attr} did not match expected type "
                f"{_display_annotation(annotation)}"
            ) from None
577

578

579
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each compiled invariant stored in ``klass.__representation_invariants__`` is evaluated with
    ``global_scope`` as globals and ``self`` bound to ``instance``. Re-entrant calls for the same
    instance (e.g. an invariant that triggers another check) return immediately.

    Raises:
        AssertionError: if evaluating an invariant itself raises AssertionError.
        PyTAContractError: if an invariant evaluates to a falsy value.

    If an invariant raises NameError for a name that is an attribute of ``instance``, a warning
    suggesting ``self.<name>`` is printed to stderr. Invariants that cannot be evaluated for
    any other reason are skipped and only reported through ``_debug``.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    # Recursion marker; always removed again in the ``finally`` clause below.
    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except NameError as e:
                # Get the missing name
                missing = getattr(e, "name", None)
                if missing is None:
                    # Failsafe for version 3.9
                    message = re.search(r"name '(.+?)' is not defined", str(e))
                    if message:
                        missing = message.group(1)

                # Check if missing name is an attribute
                if missing is not None and hasattr(instance, missing):
                    print(
                        f"[WARNING] Could not find variable `{missing}` when evaluating representation invariant. Did you mean `self.{missing}`?",
                        file=sys.stderr,
                    )
                else:
                    _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            except Exception:
                # Best effort: an invariant that cannot be evaluated is skipped. Unlike the bare
                # ``except:`` this replaces, KeyboardInterrupt/SystemExit still propagate.
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'{instance.__class__.__name__} representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        delattr(instance, "__pyta_currently_checking")
635

636

637
def _get_legal_return_val_var_name(var_dict: dict) -> str:
    """Return a variable name for the function's return value that is not a key of var_dict.

    The name starts as "__function_return_value__". Underscores are appended one at a time
    until the name no longer collides with a name in var_dict. Postconditions refer to the
    return value through this name.
    """
    base_name = "__function_return_value__"
    extra_underscores = 0
    while base_name + "_" * extra_underscores in var_dict:
        extra_underscores += 1
    return base_name + "_" * extra_underscores
649

650

651
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """Return assertion with every FUNCTION_RETURN_VALUE placeholder replaced by return_val_var_name.

    An assertion that does not contain the placeholder is returned unchanged. In that case
    return_val_var_name is never used, so it may be None.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion
    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
662

663

664
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the given assertions are still satisfied.

    Args:
        wrapped: the function (or bound method) whose conditions are checked. The compiled
            conditions are read from ``__preconditions__`` / ``__postconditions__`` on the
            underlying function.
        function_locals: mapping of parameter names to the argument values of the call.
        condition_type: "precondition" or "postcondition". Any other value checks nothing.
        function_return_val: the call's return value. For postconditions it is bound under the
            variable name stored alongside each compiled postcondition.

    Raises:
        AssertionError: if evaluating a condition itself raises AssertionError.
        PyTAContractError: if a condition evaluates to a falsy value.

    Conditions that raise any other exception are skipped and only reported through ``_debug``.
    """
    # Check bounded function
    target = wrapped.__func__ if hasattr(wrapped, "__self__") else wrapped
    if condition_type == "precondition":
        assertions = getattr(target, "__preconditions__", [])
    elif condition_type == "postcondition":
        assertions = getattr(target, "__postconditions__", [])
    else:
        assertions = []

    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            # Postcondition tuples carry the variable name chosen for the return value.
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: a condition that cannot be evaluated is skipped, not reported as
            # violated. KeyboardInterrupt/SystemExit are no longer swallowed.
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f" and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string}{return_val_string}"
                )
707

708

709
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.
    parse_token is matched case-insensitively, so e.g. "precondition: x" and
    "Representation invariants:" are recognized.

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    The first docstring line matching either form is used. Lines that merely start
    with parse_token (e.g. prose like "Precondition checking ...") are skipped.
    Return an empty list when no line matches.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # obj is an astroid node: read the docstring from its doc_node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__", None) or ""
    lines = [line.strip() for line in docstring.split("\n")]

    # Compare lowercased prefixes so the form checks agree with the case-insensitive
    # search; otherwise a line such as "precondition: x" would be located but then
    # rejected, silently yielding no assertions.
    single_prefix = (parse_token + ":").lower()
    group_prefix = (parse_token + "s:").lower()

    for index, line in enumerate(lines):
        lowered = line.lower()
        if lowered.startswith(single_prefix):
            return [line[len(single_prefix) :].strip()]
        if lowered.startswith(group_prefix):
            assertions = []
            for entry in lines[index + 1 :]:
                if entry.startswith("-"):
                    assertion = entry[1:].strip()
                    if hasattr(obj, "__qualname__"):
                        _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                    assertions.append(assertion)
                elif entry != "":
                    # Any other non-blank text ends the group.
                    break
            return assertions
    return []
751

752

753
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return the repr of value, abbreviated when it is too long to read comfortably.

    When DEBUG_CONTRACTS is False and the repr is longer than max_length characters,
    keep equal-length head and tail slices joined by "...", so the result is at
    most max_length characters long.

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text
    # Split the budget left after the 3-character ellipsis between both ends.
    keep = (max_length - 3) // 2
    return f"{text[:keep]}...{text[-keep:]}"
767

768

769
def _display_annotation(annotation: Any) -> str:
20✔
770
    """Return a human-friendly representation of the given type annotation.
771

772
    >>> _display_annotation(int)
773
    'int'
774
    >>> _display_annotation(list[int])
775
    'list[int]'
776
    >>> from typing import List
777
    >>> _display_annotation(List[int])
778
    'typing.List[int]'
779
    """
780
    if annotation is type(None):  # Use 'None' instead of 'NoneType'
20✔
781
        return "None"
×
782
    if hasattr(annotation, "__origin__"):  # Generic type annotations
20✔
783
        return repr(annotation)
20✔
784
    elif hasattr(annotation, "__name__"):
20✔
785
        return annotation.__name__
20✔
786
    else:
787
        return repr(annotation)
×
788

789

790
def _get_module(obj: Any) -> ModuleType:
20✔
791
    """Return the module where obj was defined (normally obj.__module__).
792

793
    NOTE: this function defines a special case when using PyCharm and the file
794
    defining the object is "Run in Python Console". In this case, the pydevd runner
795
    renames the '__main__' module to 'pydev_umd', and so we need to access that
796
    module instead. This behaviour can be disabled by setting RENAME_MAIN_TO_PYDEV_UMD
797
    to False.
798
    """
799
    module_name = obj.__module__
20✔
800
    module = sys.modules[module_name]
20✔
801

802
    if (
20✔
803
        module_name != "__main__"
804
        or not RENAME_MAIN_TO_PYDEV_UMD
805
        or _PYDEV_UMD_NAME not in sys.modules
806
    ):
807
        return module
20✔
808

809
    # Get a function/class name to check whether it is defined in the module
810
    if isinstance(obj, (FunctionType, type)):
×
811
        name = obj.__name__
×
812
    else:
813
        # For any other type of object, be conservative and just return the module
814
        return module
×
815

816
    if name in vars(module):
×
817
        return module
×
818
    else:
819
        return sys.modules[_PYDEV_UMD_NAME]
×
820

821

822
def _debug(msg: str) -> None:
    """Log msg at DEBUG level through the root logger, if DEBUG_CONTRACTS is set.

    This is a no-op when DEBUG_CONTRACTS is False.
    """
    if DEBUG_CONTRACTS:
        # basicConfig only installs a handler when the root logger has none yet,
        # so repeated calls do not stack handlers.
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
831

832

833
def _set_invariants(klass: type) -> None:
20✔
834
    """Retrieve and set the representation invariants of this class"""
835
    # Update representation invariants from this class' docstring and those of its superclasses.
836
    rep_invariants: list[tuple[str, CodeType]] = []
20✔
837

838
    # Iterate over all inherited classes except builtins
839
    for cls in reversed(klass.__mro__):
20✔
840
        if "__representation_invariants__" in cls.__dict__:
20✔
841
            rep_invariants.extend(cls.__representation_invariants__)
20✔
842
        elif cls.__module__ != "builtins":
20✔
843
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
20✔
844
            # Try compiling assertions
845
            for assertion in assertions:
20✔
846
                try:
20✔
847
                    compiled = compile(assertion, "<string>", "eval")
20✔
848
                except:
×
849
                    _debug(
×
850
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
851
                    )
852
                    continue
×
853
                rep_invariants.append((assertion, compiled))
20✔
854

855
    setattr(klass, "__representation_invariants__", rep_invariants)
20✔
856

857

858
def validate_invariants(obj: object) -> None:
    """Check that the representation invariants of obj are satisfied.

    A PyTAContractError raised while checking is re-raised as a plain
    AssertionError with the same message.
    """
    owner_class = obj.__class__
    # The namespace of the class's defining module is handed to _check_invariants
    # (presumably used as the evaluation globals; confirm in _check_invariants).
    defining_module = _get_module(owner_class)

    try:
        _check_invariants(obj, owner_class, defining_module.__dict__)
    except PyTAContractError as violation:
        # Suppress the chained context so users see only the assertion message.
        raise AssertionError(str(violation)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc