• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 13877193973

15 Mar 2025 10:17PM UTC coverage: 92.996% (+0.03%) from 92.967%
13877193973

Pull #1162

github

web-flow
Merge 1c11ccb99 into 2b00cbabf
Pull Request #1162: Relax Representation Invariant Checking Logic for Recursive Classes

23 of 23 new or added lines in 1 file covered. (100.0%)

20 existing lines in 1 file now uncovered.

3293 of 3541 relevant lines covered (93.0%)

17.68 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.29
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.
2

3
Representation invariants, preconditions, and postconditions are parsed, compiled, and stored.
4
Below are some notes on how they are stored.
5
    - Representation invariants are stored in a class attribute __representation_invariants__
6
    as a list [(assertion, compiled)].
7
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
8
    [(assertion, compiled)].
9
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
10
    [(assertion, compiled, return_val_var_name)].
11
"""
12

13
from __future__ import annotations
20✔
14

15
import inspect
20✔
16
import logging
20✔
17
import sys
20✔
18
import typing
20✔
19
from types import CodeType, FunctionType, ModuleType
20✔
20
from typing import (
20✔
21
    Any,
22
    Callable,
23
    Collection,
24
    Optional,
25
    TypeVar,
26
    Union,
27
    get_args,
28
    get_origin,
29
    overload,
30
)
31

32
import wrapt
20✔
33
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
20✔
34

35
# Configuration options.
# These are module-level flags; reassign them on the module to change behaviour.

ENABLE_CONTRACT_CHECKING = True
"""
Master switch for contract checking. Set to False to turn contract checking off.
"""

DEBUG_CONTRACTS = False
"""
When True, a debugging message is logged for each contract as it is checked.
"""

RENAME_MAIN_TO_PYDEV_UMD = True
"""
When True, apply the workaround for PyCharm's "Run File in Python Console" action,
which exposes the ``__main__`` module under a different name.
Leave this as True unless you have a specific reason to change it.
"""

STRICT_NUMERIC_TYPES = True
"""
When False, a more specific numeric type (e.g. int) is accepted where a more general
numeric type (e.g. float) is annotated.
"""

# Module name PyCharm uses for __main__ under "Run File in Python Console".
_PYDEV_UMD_NAME = "pydev_umd"

# Default maximum length of a value's repr in contract error messages.
_DEFAULT_MAX_VALUE_LENGTH = 30
# Placeholder used in postconditions to refer to the function's return value.
FUNCTION_RETURN_VALUE = "$return_value"
63

64

65
class PyTAContractError(Exception):
    """Raised when a PyTA precondition, postcondition, type check or representation invariant fails.

    Within this module it is generally caught and re-raised as an AssertionError.
    """
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Automatically check contracts for all functions and classes in the given modules.

    By default (when called with no arguments), the current module is used.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names = mod_names + ("__main__",)

        # Also add _PYDEV_UMD_NAME, handling when the file is being run in PyCharm
        # with the "Run in Python Console" action.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names = mod_names + (_PYDEV_UMD_NAME,)

    # Look up every module first; names missing from sys.modules map to None.
    modules = [sys.modules.get(module_name, None) for module_name in mod_names]

    # Every member is filtered against the same set of module names, so build it once
    # instead of once per member.
    allowed_module_names = set(mod_names)

    processed_module_ids = set()
    for module in modules:
        if not module:
            # Module name was passed in incorrectly.
            continue
        if id(module) in processed_module_ids:
            # The same module was listed more than once (e.g. "__main__" passed explicitly while
            # decorate_main is True); processing it again would pass its already-decorated
            # members through check_contracts a second time.
            continue
        processed_module_ids.add(id(module))

        for name, value in inspect.getmembers(module):
            if inspect.isfunction(value) or inspect.isclass(value):
                module.__dict__[name] = check_contracts(value, module_names=allowed_module_names)
102

103

104
@wrapt.decorator
def _enable_function_contracts(wrapped, instance, args, kwargs):
    """wrapt decorator that checks a function's contracts on every call.

    Contract violations reach the caller as AssertionError.
    """
    # When `instance` is itself a class (the class-method case), the contracts are checked
    # as if there were no bound instance.
    bound_instance = None if inspect.isclass(instance) else instance
    try:
        return _check_function_contracts(wrapped, bound_instance, args, kwargs)
    except PyTAContractError as contract_error:
        raise AssertionError(str(contract_error)) from None
115

116

117
# Wildcard type variable: any class passed to check_contracts is returned with the same type.
Class = TypeVar("Class", bound=type)


@overload
def check_contracts(
    func: FunctionType, module_names: Optional[set[str]] = None
) -> FunctionType: ...


@overload
def check_contracts(func: Class, module_names: Optional[set[str]] = None) -> Class: ...


def check_contracts(
    func_or_class: Union[Class, FunctionType], module_names: Optional[set[str]] = None
) -> Union[Class, FunctionType]:
    """Decorator that turns on contract checking for a function or a class.

    Applied to a class, it turns on contract checking for every method defined in that class
    (the class is modified in place and returned).
    When module_names is not None, an object whose defining module (its ``__module__``) is not
    in module_names is returned unchanged.

    Example:

        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """
    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    excluded_by_module = (
        module_names is not None and func_or_class.__module__ not in module_names
    )
    if excluded_by_module:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class

    if inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)

    if inspect.isclass(func_or_class):
        # Classes are decorated in place rather than wrapped.
        add_class_invariants(func_or_class)

    # Decorated classes, and any other kind of object, are handed back as-is.
    return func_or_class
167

168

169
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    This stores klass's representation invariants on the class. It wraps every routine defined
    directly in klass: static and class methods get contract checks only, while other methods
    also get invariant checks. It then replaces ``klass.__setattr__`` so that attribute
    assignments are type-checked against the class annotations. Assignments made from outside
    the class's methods also re-validate the representation invariants immediately.

    Does nothing if contract checking is disabled or if klass has already been decorated.
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # Contract checking is off, or the class has already been decorated.
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        Check representation invariants for this class when not within an instance method of the class.
        """
        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        nonlocal cls_annotations
        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

        if name in cls_annotations:
            try:
                # Report the attribute actually being assigned (`name`); the enclosing
                # function's loop variable `attr` is unrelated to this assignment.
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} for attribute {name} did not match expected type "
                    f"{_display_annotation(cls_annotations[name])}"
                ) from None

        # Remember the previous value so the assignment can be rolled back if it violates
        # a representation invariant.
        original_attr_value_exists = False
        original_attr_value = None
        if hasattr(self, name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)
        super(klass, self).__setattr__(name, value)

        frame_locals = inspect.currentframe().f_back.f_locals
        caller_self = frame_locals.get("self")
        if not isinstance(caller_self, type(self)):
            # The caller's `self` is not an instance of self's class (or a subclass), so the
            # assignment does not come from a method of such an instance: validate the
            # invariants now and undo the assignment if they fail.
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None
        else:
            # The assignment comes from a method of an instance of this class (or a subclass):
            # record self so the method's wrapper can check it after the method returns.
            caller_klass = type(caller_self)
            if "__mutated_instances__" in caller_klass.__dict__:
                mutated_instances = caller_klass.__dict__["__mutated_instances__"]
                # Compare by identity: `self not in mutated_instances` would call a user-defined
                # __eq__, which may treat distinct instances as equal and skip recording one.
                if not any(recorded is self for recorded in mutated_instances):
                    mutated_instances.append(self)

    # Iterate over a snapshot because setattr below writes back into klass.__dict__.
    for attr, value in list(klass.__dict__.items()):
        if inspect.isroutine(value):
            if isinstance(value, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr, check_contracts(value))
            else:
                setattr(klass, attr, _instance_method_wrapper(value, klass))

    klass.__setattr__ = new_setattr
244

245

246
def _check_function_contracts(wrapped, instance, args, kwargs):
    """Call ``wrapped(*args, **kwargs)``, checking its type annotations and contracts.

    Before the call, this checks the types of annotated parameters and the preconditions.
    After the call, it checks the annotated return type and the postconditions. Pre- and
    postconditions are parsed from the docstring and compiled on first use. They are then
    cached on the underlying function as ``__preconditions__`` and ``__postconditions__``.

    Args:
        wrapped: the function being called (possibly a bound method).
        instance: the object the method is bound to, or None.
        args: the positional arguments of the call, not including ``instance``.
        kwargs: the keyword arguments of the call.

    Returns:
        The value returned by ``wrapped``.

    Raises:
        PyTAContractError: if an argument or the return value does not match its annotation,
            or a precondition/postcondition evaluates to a falsy value.
    """
    code = wrapped.__code__
    # co_varnames lists the positional parameters first, then the keyword-only parameters.
    params = code.co_varnames[: code.co_argcount]
    keyword_params = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount]
    if instance is not None:
        klass_mod = _get_module(type(instance))
        annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
    else:
        annotations = typing.get_type_hints(wrapped)
    args_with_self = args if instance is None else (instance,) + args

    # Pair each parameter name with the value it receives. Positional arguments are matched
    # by position. Keyword arguments are matched by name so that they are also type-checked
    # and visible to preconditions and postconditions.
    bound_args = list(zip(params, args_with_self))
    positionally_bound = {param for param, _ in bound_args}
    bound_args.extend(
        (param, arg)
        for param, arg in kwargs.items()
        if param in keyword_params and param not in positionally_bound
    )

    # Check function parameter types
    for param, arg in bound_args:
        if param in annotations:
            try:
                _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
                if STRICT_NUMERIC_TYPES:
                    check_type_strict(param, arg, annotations[param])
                else:
                    check_type(arg, annotations[param])
            except (TypeError, TypeCheckError):
                additional_suggestions = _get_argument_suggestions(arg, annotations[param])

                raise PyTAContractError(
                    f"Argument value {_display_value(arg)} for {wrapped.__name__} parameter {param} "
                    f"did not match expected type {_display_annotation(annotations[param])}"
                    + (f"\n{additional_suggestions}" if additional_suggestions else "")
                )

    function_locals = dict(bound_args)

    # Contracts are cached on the underlying function rather than on a bound method.
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped

    # Parse and compile preconditions on first use: a list of (assertion, compiled) pairs.
    if not hasattr(target, "__preconditions__"):
        target.__preconditions__ = []
        for precondition in parse_assertions(wrapped):
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(wrapped, function_locals)

    r = wrapped(*args, **kwargs)

    # Check return type
    if "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"Return value {_display_value(r)} for {wrapped.__name__} did not match "
                f"expected type {_display_annotation(return_type)}"
            )

    # Parse and compile postconditions on first use: a list of
    # (assertion, compiled, return_val_var_name) triples.
    if not hasattr(target, "__postconditions__"):
        target.__postconditions__ = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        for postcondition in parse_assertions(wrapped, parse_token="Postcondition"):
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r
341

342

343
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Check that ``value`` matches ``expected_type``, treating numeric types strictly.

    Within the numeric hierarchy (bool, int, float, complex), a more specific type is not
    accepted where a more general one is expected; e.g. a bool is rejected for int and an
    int is rejected for float. Does nothing when contract checking is disabled.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``.
    """
    if ENABLE_CONTRACT_CHECKING:
        try:
            _check_inner_type(argname, value, expected_type)
        except (TypeError, TypeCheckError):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
355

356

357
def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
20✔
358
    """Recursively checks if `value` matches `expected_type` for strict type validation, specifically supports checking
359
    collections (list[int], dicts[float]) and Union types (bool | int).
360
    """
361
    inner_types = get_args(expected_type)
20✔
362
    outer_type = get_origin(expected_type)
20✔
363
    if outer_type is None:
20✔
364
        if (
20✔
365
            (type(value) is bool and expected_type in {int, float, complex})
366
            or (type(value) is int and expected_type in {float, complex})
367
            or (type(value) is float and expected_type is complex)
368
        ):
369
            raise TypeError(
20✔
370
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
371
            )
372
        else:
373
            check_type(
20✔
374
                value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
375
            )
376
    elif outer_type is typing.Union:
20✔
377
        for inner_type in inner_types:
20✔
378
            try:
20✔
379
                _check_inner_type(argname, value, inner_type)
20✔
380
                return
20✔
381
            except (TypeError, TypeCheckError):
20✔
382
                pass
20✔
383
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
384
    elif outer_type in {list, set}:
20✔
385
        if isinstance(value, outer_type):
20✔
386
            for item in value:
20✔
387
                _check_inner_type(argname, item, inner_types[0])
20✔
388
        else:
389
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
390
    elif outer_type is dict:
20✔
391
        if isinstance(value, dict):
20✔
392
            for key, item in value.items():
20✔
393
                _check_inner_type(argname, key, inner_types[0])
20✔
394
                _check_inner_type(argname, item, inner_types[1])
20✔
395
        else:
396
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
397
    elif outer_type is tuple:
20✔
398
        if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
20✔
399
            for item in value:
20✔
400
                _check_inner_type(argname, item, inner_types[0])
20✔
401
        elif isinstance(value, tuple) and len(value) == len(inner_types):
20✔
402
            for item, inner_type in zip(value, inner_types):
20✔
403
                _check_inner_type(argname, item, inner_type)
20✔
404
        else:
405
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
406
    else:
407
        check_type(
20✔
408
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
409
        )
410

411

412
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
    """Return a hint for a mistyped argument, or "" when there is nothing to suggest.

    The only hint so far covers passing a class where an instance of that class was expected.
    """
    try:
        passed_class_not_instance = isinstance(arg, type) and issubclass(arg, annotation)
    except TypeError:
        # issubclass rejects annotations that are not plain classes (e.g. list[int]).
        passed_class_not_instance = False

    if not passed_class_not_instance:
        return ""
    return "Did you mean {cls}(...) instead of {cls}?".format(cls=arg.__name__)
421

422

423
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Return ``wrapped`` (an instance method defined in ``klass``) wrapped with contract checks.

    Each call of the returned method does the following:
      1. checks the method's own contracts (argument/return types, pre/postconditions);
      2. unless the instance's ``__init__`` is still on the call stack, checks the class's
         attribute type annotations on the instance, and the representation invariants of the
         instance and of every decorated instance that was assigned to during the call.

    Contract violations are raised as AssertionError.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        # Create an accumulator to store the instances mutated across this function call.
        # Store and restore existing mutated instance lists in case the instance method
        # executes another instance method.
        instance_klass = type(instance)
        mutated_instances_to_restore = []
        if hasattr(instance_klass, "__mutated_instances__"):
            mutated_instances_to_restore = getattr(instance_klass, "__mutated_instances__")
        setattr(instance_klass, "__mutated_instances__", [])

        try:
            r = _check_function_contracts(wrapped, instance, args, kwargs)
            if _instance_init_in_callstack(instance):
                # Skip the checks while the instance's __init__ is still running; the object
                # may be only partially initialized.
                return r
            _check_class_type_annotations(klass, instance)
            klass_mod = _get_module(klass)
            if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
                _check_invariants(instance, klass, klass_mod.__dict__)

                # Additionally check RI violations on PyTA-decorated instances that were mutated
                # across the function call. Check each one against its own class's invariants,
                # evaluated in that class's module. A mutated instance can belong to a different
                # class than klass (e.g. a base class of the caller's class).
                mutated_instances = getattr(instance_klass, "__mutated_instances__", [])
                for mutated_instance in mutated_instances:
                    mutated_klass = type(mutated_instance)
                    mutated_klass_mod = _get_module(mutated_klass)
                    _check_invariants(mutated_instance, mutated_klass, mutated_klass_mod.__dict__)
        except PyTAContractError as e:
            raise AssertionError(str(e)) from None
        else:
            return r
        finally:
            setattr(instance_klass, "__mutated_instances__", mutated_instances_to_restore)

    return wrapper(wrapped)
457

458

459
def _instance_init_in_callstack(instance: Any) -> bool:
    """Return whether instance's __init__ is part of the current callstack.

    A frame counts as instance's __init__ when all of the following hold: its code object is
    named "__init__", its first local variable is named "self", and that local is bound to
    ``instance`` (compared by identity).

    Note: due to the nature of the check, externally defined __init__ functions with
    'self' defined as the first parameter may pass this check.
    """
    frame = inspect.currentframe().f_back
    try:
        while frame is not None:
            code = frame.f_code
            # co_name is the name inspect.getframeinfo reports, but reading it does not load
            # source files. The code-level tests run before the comparatively costly f_locals.
            # Slicing co_varnames avoids an IndexError for code with no local variables.
            if (
                code.co_name == "__init__"
                and code.co_varnames[:1] == ("self",)
                and frame.f_locals.get("self") is instance
            ):
                return True
            frame = frame.f_back
        return False
    finally:
        # Drop the frame reference promptly to avoid keeping frames alive in a reference cycle.
        del frame
478

479

480
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Every attribute annotated on klass (as resolved by typing.get_type_hints) is read from
    instance and checked against its annotation, including every item of collections.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an attribute's value does not match its annotation.
    """
    klass_mod = _get_module(klass)
    cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

    for attr, annotation in cls_annotations.items():
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress the typeguard exception context, as the attribute check performed in
            # __setattr__ does for the identical error.
            raise AssertionError(
                f"Value {_display_value(value)} for attribute {attr} did not match expected type "
                f"{_display_annotation(annotation)}"
            ) from None
501

502

503
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each invariant stored on klass is evaluated with global_scope as its globals and ``self``
    bound to instance. An invariant that cannot be evaluated is skipped with a debug warning.

    Raises:
        PyTAContractError: if an invariant evaluates to a falsy value.
        AssertionError: if evaluating an invariant itself raises AssertionError.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        # (e.g. an invariant that calls a checked method on the same object).
        return

    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except Exception:
                # Best effort: skip invariants that fail to evaluate, but do not swallow
                # KeyboardInterrupt/SystemExit as the previous bare `except:` did.
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'{instance.__class__.__name__} representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        delattr(instance, "__pyta_currently_checking")
542

543

544
def _get_legal_return_val_var_name(var_dict: dict) -> str:
    """Return a name for the function's return value that is not a key of var_dict.

    Starting from "__function_return_value__", underscores are appended one at a time
    until the name is unused. Postconditions refer to the return value by this name.
    """
    base_name = "__function_return_value__"
    extra_underscores = 0
    while base_name + "_" * extra_underscores in var_dict:
        extra_underscores += 1
    return base_name + "_" * extra_underscores
556

557

558
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """Return assertion with every FUNCTION_RETURN_VALUE placeholder replaced by return_val_var_name.

    An assertion that does not mention FUNCTION_RETURN_VALUE is returned unchanged; in that
    case return_val_var_name is not used at all.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion
    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
569

570

571
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the given assertions are still satisfied.

    Evaluates the compiled preconditions (condition_type="precondition") or postconditions
    (condition_type="postcondition") stored on the function. The evaluation namespace is the
    function's globals plus function_locals. For postconditions, function_return_val is also
    bound under the variable name stored with each assertion. An assertion that cannot be
    evaluated is skipped with a debug warning.

    Raises:
        PyTAContractError: if an assertion evaluates to a falsy value.
        AssertionError: if evaluating an assertion itself raises AssertionError.
    """
    # Check bounded function
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: skip assertions that fail to evaluate (e.g. NameError), but do not
            # swallow KeyboardInterrupt/SystemExit as the previous bare `except:` did.
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f" and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string}{return_val_string}"
                )
614

615

616
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.
    The token is matched case-insensitively, so for example "Representation invariants:"
    is recognized when parse_token is "Representation Invariant".

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    Only the first line (after stripping surrounding whitespace) that starts with one of
    these two headers is used. Other lines that merely begin with the token are ignored.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # Check if obj is an astroid node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__", None) or ""
    lines = [line.strip() for line in docstring.split("\n")]

    token = parse_token.lower()
    single_header = token + ":"
    group_header = token + "s:"

    for index, line in enumerate(lines):
        lowered = line.lower()
        if lowered.startswith(single_header):
            # Form 1: the condition follows the header on the same line.
            return [line[len(single_header) :].strip()]
        if lowered.startswith(group_header):
            # Form 2: collect "- <cond>" lines until the first non-blank, non-bullet line.
            assertions = []
            for entry in lines[index + 1 :]:
                if entry.startswith("-"):
                    assertion = entry[1:].strip()
                    if hasattr(obj, "__qualname__"):
                        _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                    assertions.append(assertion)
                elif entry != "":
                    break
            return assertions

    return []
658

659

660
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return repr(value), shortened for error messages.

    Unless DEBUG_CONTRACTS is True, a repr longer than max_length is cut down to its start and
    end joined by "...", so that the result is at most max_length characters.

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text
    keep = (max_length - 3) // 2
    return f"{text[:keep]}...{text[-keep:]}"
674

675

676
def _display_annotation(annotation: Any) -> str:
    """Return a readable string for a type annotation, for use in error messages.

    >>> _display_annotation(int)
    'int'
    >>> _display_annotation(list[int])
    'list[int]'
    >>> from typing import List
    >>> _display_annotation(List[int])
    'typing.List[int]'
    """
    if annotation is type(None):
        return "None"  # Show 'None' rather than 'NoneType'

    # Parameterized generics keep their full repr; plain named objects show just their name.
    is_generic = hasattr(annotation, "__origin__")
    if not is_generic and hasattr(annotation, "__name__"):
        return annotation.__name__
    return repr(annotation)
695

696

697
def _get_module(obj: Any) -> ModuleType:
    """Return the module in which obj was defined (normally sys.modules[obj.__module__]).

    Special case for PyCharm's "Run in Python Console" action: the pydevd runner renames the
    '__main__' module to 'pydev_umd'. So when obj claims to live in '__main__', that rename is
    enabled (RENAME_MAIN_TO_PYDEV_UMD), and 'pydev_umd' is loaded, a function or class whose
    name is absent from '__main__' is looked up in 'pydev_umd' instead.
    """
    module_name = obj.__module__
    module = sys.modules[module_name]

    may_be_renamed_main = (
        module_name == "__main__"
        and RENAME_MAIN_TO_PYDEV_UMD
        and _PYDEV_UMD_NAME in sys.modules
    )
    # Only functions and classes have a name we can look up; for anything else, be
    # conservative and keep the nominal module.
    if not may_be_renamed_main or not isinstance(obj, (FunctionType, type)):
        return module

    return module if obj.__name__ in vars(module) else sys.modules[_PYDEV_UMD_NAME]
727

728

729
def _debug(msg: str) -> None:
    """Log msg at DEBUG level via the root logger, but only when DEBUG_CONTRACTS is True."""
    if DEBUG_CONTRACTS:
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
738

739

740
def _set_invariants(klass: type) -> None:
    """Retrieve and set the representation invariants of this class.

    Invariants are collected from the docstrings of klass and its non-builtin superclasses,
    base classes first. They are stored on klass as ``__representation_invariants__``, a list
    of (assertion, compiled) pairs. Invariants that do not compile are skipped with a debug
    warning. Each assertion string is kept only once.
    """
    rep_invariants: list[tuple[str, CodeType]] = []
    seen_assertions = set()

    # Iterate over all inherited classes except builtins, most basic class first
    for cls in reversed(klass.__mro__):
        if "__representation_invariants__" in cls.__dict__:
            # An already-decorated ancestor's list also contains the invariants of its own
            # ancestors, some of which may have been collected earlier in this loop.
            candidates = cls.__representation_invariants__
        elif cls.__module__ != "builtins":
            candidates = []
            for assertion in parse_assertions(cls, parse_token="Representation Invariant"):
                try:
                    compiled = compile(assertion, "<string>", "eval")
                except Exception:
                    _debug(
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
                    )
                    continue
                candidates.append((assertion, compiled))
        else:
            continue

        for assertion, compiled in candidates:
            if assertion not in seen_assertions:
                seen_assertions.add(assertion)
                rep_invariants.append((assertion, compiled))

    setattr(klass, "__representation_invariants__", rep_invariants)
763

764

765
def validate_invariants(obj: object) -> None:
    """Check obj against the representation invariants of its class.

    Raises:
        AssertionError: if one of the invariants is violated.
    """
    obj_class = obj.__class__
    defining_module = _get_module(obj_class)

    try:
        _check_invariants(obj, obj_class, defining_module.__dict__)
    except PyTAContractError as violation:
        raise AssertionError(str(violation)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc