• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 14159714644

30 Mar 2025 08:31PM UTC coverage: 92.559% (+0.002%) from 92.557%
14159714644

Pull #1164

github

web-flow
Merge cc86cd7f3 into ba8b9c490
Pull Request #1164: Fix issue with inline comments in docstring assertions

1 of 2 new or added lines in 2 files covered. (50.0%)

13 existing lines in 1 file now uncovered.

3346 of 3615 relevant lines covered (92.56%)

17.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.41
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.
2

3
Representation invariants, preconditions, and postconditions are parsed, compiled, and stored.
4
Below are some notes on how they are stored.
5
    - Representation invariants are stored in a class attribute __representation_invariants__
6
    as a list [(assertion, compiled)].
7
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
8
    [(assertion, compiled)].
9
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
10
    [(assertion, compiled, return_val_var_name)].
11
"""
12

13
from __future__ import annotations
20✔
14

15
import inspect
20✔
16
import logging
20✔
17
import sys
20✔
18
import typing
20✔
19
from types import CodeType, FunctionType, ModuleType
20✔
20
from typing import (
20✔
21
    Any,
22
    Callable,
23
    Collection,
24
    Optional,
25
    TypeVar,
26
    Union,
27
    get_args,
28
    get_origin,
29
    overload,
30
)
31

32
import wrapt
20✔
33
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
20✔
34

35
# Configuration options

ENABLE_CONTRACT_CHECKING = True
"""
Whether contracts are checked at all. Set to False to turn every contract check off.
"""

DEBUG_CONTRACTS = False
"""
Whether to log debugging messages while contracts are being checked.
"""

RENAME_MAIN_TO_PYDEV_UMD = True
"""
Whether to look up objects in PyCharm's renamed main module ("Run File in Python Console").
This workaround rarely needs to be turned off.
"""

STRICT_NUMERIC_TYPES = True
"""
Whether numeric annotations are checked strictly. Set to False to let more specific numeric
types (e.g. int for a float annotation) satisfy more general annotations.
"""

# Module name that PyCharm's "Run in Python Console" gives the renamed '__main__' module
_PYDEV_UMD_NAME = "pydev_umd"


# Default number of characters kept when displaying a value in an error message
_DEFAULT_MAX_VALUE_LENGTH = 30
# Placeholder for the function's return value inside postconditions
FUNCTION_RETURN_VALUE = "$return_value"
63

64

65
class PyTAContractError(Exception):
    """Raised when a PyTA contract (type annotation, precondition, postcondition, or
    representation invariant) is violated."""
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Automatically check contracts for all functions and classes in the given modules.

    By default (when called with no arguments), the current module is used.

    Each module object is processed at most once, even when it is reachable under more than
    one of the given names (e.g. ``check_all_contracts(__name__)`` while running as
    ``__main__`` with decorate_main=True). Otherwise its functions could be wrapped a
    second time.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names = mod_names + ("__main__",)

        # Also add _PYDEV_UMD_NAME, handling when the file is being run in PyCharm
        # with the "Run in Python Console" action.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names = mod_names + (_PYDEV_UMD_NAME,)

    # Resolve names to module objects. Skip names that were never imported, and skip
    # modules already seen under another name.
    modules = []
    seen_module_ids = set()
    for module_name in mod_names:
        module = sys.modules.get(module_name, None)
        if not module or id(module) in seen_module_ids:
            continue
        seen_module_ids.add(id(module))
        modules.append(module)

    # The set of allowed module names is the same for every member, so build it once.
    module_names = set(mod_names)
    for module in modules:
        for name, value in inspect.getmembers(module):
            if inspect.isfunction(value) or inspect.isclass(value):
                module.__dict__[name] = check_contracts(value, module_names=module_names)
102

103

104
@wrapt.decorator
def _enable_function_contracts(wrapped, instance, args, kwargs):
    """Decorator that checks the contracts of a function each time it is called.

    Contract violations (PyTAContractError) are re-raised as AssertionError without
    exception chaining.
    """
    # When `instance` is itself a class, this is a class method call, so there is no
    # instance to pass along.
    called_as_classmethod = instance is not None and inspect.isclass(instance)
    bound_instance = None if called_as_classmethod else instance
    try:
        return _check_function_contracts(wrapped, bound_instance, args, kwargs)
    except PyTAContractError as e:
        raise AssertionError(str(e)) from None
115

116

117
# Wildcard Type Variable
Class = TypeVar("Class", bound=type)


# The overload stubs use the same parameter names as the implementation, so that keyword
# calls accepted by a type checker also work at runtime.
@overload
def check_contracts(
    func_or_class: FunctionType, module_names: Optional[set[str]] = None
) -> FunctionType: ...


@overload
def check_contracts(
    func_or_class: Class, module_names: Optional[set[str]] = None
) -> Class: ...


def check_contracts(
    func_or_class: Union[Class, FunctionType], module_names: Optional[set[str]] = None
) -> Union[Class, FunctionType]:
    """A decorator to enable contract checking for a function or class.

    When used with a class, all methods defined within the class have contract checking enabled.
    If module_names is not None, only functions or classes defined in a module whose name is
    in module_names are checked; anything else is returned unchanged.

    Args:
        func_or_class: the function or class to enable contract checking for.
        module_names: optional set of module names used to restrict which objects are checked.

    Returns:
        A contract-checking wrapper for routines. For classes, the class itself, modified in
        place. For objects that are skipped, the object unchanged.

    Example:

        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """
    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in "
            f"{func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class
    elif inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)
    elif inspect.isclass(func_or_class):
        add_class_invariants(func_or_class)
        return func_or_class
    else:
        # Default action: neither a routine nor a class, so there is nothing to check
        return func_or_class
167

168

169
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    This function:
      - parses and stores the representation invariants of klass and its superclasses;
      - wraps every routine defined directly on klass. Static and class methods get function
        contract checking only. Other methods are wrapped with _instance_method_wrapper;
      - replaces klass.__setattr__ so that assignments to annotated attributes are
        type-checked. Assignments made from outside a method of an instance of the same type
        are also checked against the representation invariants.

    Does nothing if contract checking is disabled or klass has already been processed.
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # Contract checking is disabled, or the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        If name is annotated on the class, value must match the annotation.

        Representation invariants are checked only when the caller is not a method of an
        instance of type(self). A violating assignment is undone (the previous value is
        restored, or the new attribute deleted) before the AssertionError is raised.

        Assignments to other instances made from within such a method are recorded in the
        caller's class __mutated_instances__ list, when it exists.
        """
        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        nonlocal cls_annotations
        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

        if name in cls_annotations:
            try:
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} for attribute {name} did not match expected type "
                    f"{_display_annotation(cls_annotations[name])}"
                ) from None
        original_attr_value_exists = False
        original_attr_value = None
        if hasattr(self, name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)
        super(klass, self).__setattr__(name, value)
        frame_locals = inspect.currentframe().f_back.f_locals
        caller_self = frame_locals.get("self")
        if not isinstance(caller_self, type(self)):
            # Only validating if the attribute is not being set in a instance/class method
            # AND caller_self is an instance of self's type
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    # Roll back the assignment before reporting the violation
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None
        elif caller_self is not self:
            # Keep track of mutations to instances that are of the same type as caller_self (and are also not `self`)
            # to enforce RIs on them only after the caller function returns.
            caller_klass = type(caller_self)
            if hasattr(caller_klass, "__mutated_instances__"):
                mutated_instances = getattr(caller_klass, "__mutated_instances__")
                if self not in mutated_instances:
                    mutated_instances.append(self)

    # Snapshot the namespace, since it is modified via setattr while iterating.
    for member_name, member in list(klass.__dict__.items()):
        if inspect.isroutine(member):
            if isinstance(member, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, member_name, check_contracts(member))
            else:
                setattr(klass, member_name, _instance_method_wrapper(member, klass))

    klass.__setattr__ = new_setattr
244

245

246
def _map_call_arguments(wrapped, instance, args, kwargs) -> dict[str, Any]:
    """Return a mapping of wrapped's parameter names to the argument values of this call.

    Positional arguments are matched, in order, to wrapped's positional parameters. The
    object the method is bound to (``self``/``cls``) comes first, when there is one.
    Keyword arguments are then added for named (non-positional-only) parameters that were
    not already filled positionally.

    Values absorbed by ``*args``/``**kwargs``, and parameters left at their default values,
    are not included.
    """
    code = wrapped.__code__
    if instance is not None:
        leading = (instance,)
    elif hasattr(wrapped, "__self__"):
        # Class methods arrive here with instance=None (see _enable_function_contracts)
        # while wrapped is bound to the class and args do not include cls. Prepend it so
        # the positional arguments line up with the parameter names.
        leading = (wrapped.__self__,)
    else:
        leading = ()

    positional_names = code.co_varnames[: code.co_argcount]
    # Positional-only parameters cannot be passed by keyword: a matching key in kwargs
    # belongs to **kwargs instead.
    keyword_names = set(
        code.co_varnames[code.co_posonlyargcount : code.co_argcount + code.co_kwonlyargcount]
    )

    arguments = dict(zip(positional_names, leading + tuple(args)))
    for name, value in kwargs.items():
        if name in keyword_names and name not in arguments:
            arguments[name] = value
    return arguments


def _check_function_contracts(wrapped, instance, args, kwargs):
    """Call wrapped, checking its type annotations, preconditions and postconditions.

    Argument types (positional and keyword) and preconditions are checked before the call.
    The return type and postconditions are checked after it. Parsed and compiled conditions
    are cached on the underlying function in __preconditions__ / __postconditions__.
    Conditions that cannot be compiled are skipped with a debug warning.

    Args:
        wrapped: the function (or bound method) being called.
        instance: the object wrapped is bound to, or None.
        args: positional arguments of the call (not including the bound instance).
        kwargs: keyword arguments of the call.

    Returns:
        The value returned by wrapped(*args, **kwargs).

    Raises:
        PyTAContractError: if an argument or the return value does not match its annotation,
            or a precondition/postcondition evaluates to a falsy value.
    """
    if instance is not None:
        klass_mod = _get_module(type(instance))
        annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
    else:
        annotations = typing.get_type_hints(wrapped)

    function_locals = _map_call_arguments(wrapped, instance, args, kwargs)

    # Check function parameter types
    for param, arg in function_locals.items():
        if param not in annotations:
            continue
        try:
            _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict(param, arg, annotations[param])
            else:
                check_type(arg, annotations[param])
        except (TypeError, TypeCheckError):
            additional_suggestions = _get_argument_suggestions(arg, annotations[param])

            raise PyTAContractError(
                f"Argument value {_display_value(arg)} for {wrapped.__name__} parameter {param} "
                f"did not match expected type {_display_annotation(annotations[param])}"
                + (f"\n{additional_suggestions}" if additional_suggestions else "")
            )

    # Conditions are cached on the underlying function, so all bound methods share them.
    target = wrapped.__func__ if hasattr(wrapped, "__self__") else wrapped

    # Check function preconditions
    if not hasattr(target, "__preconditions__"):
        target.__preconditions__: list[tuple[str, CodeType]] = []
        for precondition in parse_assertions(wrapped):
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(wrapped, function_locals)

    # Check return type
    r = wrapped(*args, **kwargs)
    if "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"Return value {_display_value(r)} for {wrapped.__name__} did not match "
                f"expected type {_display_annotation(return_type)}"
            )

    # Check function postconditions
    if not hasattr(target, "__postconditions__"):
        target.__postconditions__: list[tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        for postcondition in parse_assertions(wrapped, parse_token="Postcondition"):
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r
341

342

343
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Check ``value`` against ``expected_type``, keeping numeric types distinct.

    Within the numeric hierarchy (bool, int, float, complex) a value only satisfies an
    annotation of exactly its own type, e.g. an int does not satisfy ``float``.
    Does nothing when contract checking is disabled.

    Raises:
        TypeError: if value does not match expected_type.
    """
    if ENABLE_CONTRACT_CHECKING:
        try:
            _check_inner_type(argname, value, expected_type)
        except (TypeError, TypeCheckError):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
355

356

357
def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
20✔
358
    """Recursively checks if `value` matches `expected_type` for strict type validation, specifically supports checking
359
    collections (list[int], dicts[float]) and Union types (bool | int).
360
    """
361
    inner_types = get_args(expected_type)
20✔
362
    outer_type = get_origin(expected_type)
20✔
363
    if outer_type is None:
20✔
364
        if (
20✔
365
            (type(value) is bool and expected_type in {int, float, complex})
366
            or (type(value) is int and expected_type in {float, complex})
367
            or (type(value) is float and expected_type is complex)
368
        ):
369
            raise TypeError(
20✔
370
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
371
            )
372
        else:
373
            check_type(
20✔
374
                value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
375
            )
376
    elif outer_type is typing.Union:
20✔
377
        for inner_type in inner_types:
20✔
378
            try:
20✔
379
                _check_inner_type(argname, value, inner_type)
20✔
380
                return
20✔
381
            except (TypeError, TypeCheckError):
20✔
382
                pass
20✔
383
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
384
    elif outer_type in {list, set}:
20✔
385
        if isinstance(value, outer_type):
20✔
386
            for item in value:
20✔
387
                _check_inner_type(argname, item, inner_types[0])
20✔
388
        else:
389
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
390
    elif outer_type is dict:
20✔
391
        if isinstance(value, dict):
20✔
392
            for key, item in value.items():
20✔
393
                _check_inner_type(argname, key, inner_types[0])
20✔
394
                _check_inner_type(argname, item, inner_types[1])
20✔
395
        else:
396
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
397
    elif outer_type is tuple:
20✔
398
        if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
20✔
399
            for item in value:
20✔
400
                _check_inner_type(argname, item, inner_types[0])
20✔
401
        elif isinstance(value, tuple) and len(value) == len(inner_types):
20✔
402
            for item, inner_type in zip(value, inner_types):
20✔
403
                _check_inner_type(argname, item, inner_type)
20✔
404
        else:
405
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
20✔
406
    else:
407
        check_type(
20✔
408
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
409
        )
410

411

412
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
20✔
413
    """Returns potential suggestions for the given arg and its annotation"""
414
    try:
20✔
415
        if isinstance(arg, type) and issubclass(arg, annotation):
20✔
416
            return "Did you mean {cls}(...) instead of {cls}?".format(cls=arg.__name__)
20✔
417
    except TypeError:
20✔
418
        pass
20✔
419

420
    return ""
20✔
421

422

423
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Return the instance method wrapped, decorated to check contracts and invariants of klass.

    Each call first checks the method's own contracts. Unless the instance's __init__ is
    still on the call stack, the following are then checked:
      - klass's type annotations against the instance;
      - the representation invariants of the instance;
      - the representation invariants of every instance recorded in the class's
        __mutated_instances__ list during the call.

    PyTAContractError is re-raised as AssertionError.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        # Create an accumulator to store the instances mutated across this function call.
        # Store and restore existing mutated instance lists in case the instance method
        # executes another instance method.
        # Only the class's own namespace is consulted. A list inherited from a superclass
        # belongs to a call on a superclass instance; copying it onto this class would leave
        # a stale attribute behind after the call finishes.
        instance_klass = type(instance)
        mutated_instances_to_restore = vars(instance_klass).get("__mutated_instances__")
        setattr(instance_klass, "__mutated_instances__", [])

        try:
            r = _check_function_contracts(wrapped, instance, args, kwargs)
            if _instance_init_in_callstack(instance):
                return r
            _check_class_type_annotations(klass, instance)
            klass_mod = _get_module(klass)
            if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
                _check_invariants(instance, klass, klass_mod.__dict__)

                # Additionally check RI violations on PyTA-decorated instances that were mutated
                # across the function call.
                mutated_instances = getattr(instance_klass, "__mutated_instances__", [])
                for mutated_instance in mutated_instances:
                    # Mutated instances may be of parent class types so the invariants to check should also be
                    # for the parent class and not the child class.
                    mutated_instance_klass = type(mutated_instance)
                    mutated_instance_klass_mod = _get_module(mutated_instance_klass)
                    _check_invariants(
                        mutated_instance,
                        mutated_instance_klass,
                        mutated_instance_klass_mod.__dict__,
                    )
        except PyTAContractError as e:
            raise AssertionError(str(e)) from None
        else:
            return r
        finally:
            if mutated_instances_to_restore is None:
                delattr(instance_klass, "__mutated_instances__")
            else:
                setattr(instance_klass, "__mutated_instances__", mutated_instances_to_restore)

    return wrapper(wrapped)
468

469

470
def _instance_init_in_callstack(instance: Any) -> bool:
20✔
471
    """Return whether instance's init is part of the current callstack
472

473
    Note: due to the nature of the check, externally defined __init__ functions with
474
    'self' defined as the first parameter may pass this check.
475
    """
476
    frame = inspect.currentframe().f_back
20✔
477
    while frame:
20✔
478
        frame_context_name = inspect.getframeinfo(frame).function
20✔
479
        frame_context_self = frame.f_locals.get("self")
20✔
480
        frame_context_vars = frame.f_code.co_varnames
20✔
481
        if (
20✔
482
            frame_context_name == "__init__"
483
            and frame_context_self is instance
484
            and frame_context_vars[0] == "self"
485
        ):
486
            return True
20✔
487
        frame = frame.f_back
20✔
488
    return False
20✔
489

490

491
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Every annotation returned by typing.get_type_hints(klass) is checked against the
    corresponding attribute value of instance.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an attribute value does not match its annotated type.
    """
    klass_mod = _get_module(klass)
    cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

    for attr, annotation in cls_annotations.items():
        # NOTE(review): an annotated attribute that was never assigned raises AttributeError
        # here. Confirm that this is the intended behaviour.
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress the typeguard exception context so only the contract message is shown
            raise AssertionError(
                f"Value {_display_value(value)} for attribute {attr} did not match expected type "
                f"{_display_annotation(annotation)}"
            ) from None
512

513

514
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each compiled invariant in klass.__representation_invariants__ is evaluated with
    global_scope as globals and ``self`` bound to instance.

    An invariant that raises an exception other than AssertionError while being evaluated
    is skipped with a debug warning. Re-entrant calls for the same instance return
    immediately.

    Raises:
        PyTAContractError: if an invariant evaluates to a falsy value.
        AssertionError: if evaluating an invariant raises AssertionError.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    # Set the marker with the __setattr__ that follows type(instance) in the MRO, so the
    # __setattr__ defined directly on type(instance) is not invoked.
    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except Exception:
                # Best effort: an invariant that cannot be evaluated is skipped, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'{instance.__class__.__name__} representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        delattr(instance, "__pyta_currently_checking")
553

554

555
def _get_legal_return_val_var_name(var_dict: dict) -> str:
20✔
556
    """
557
    Add '_' to the end of __function_return_value__ until a variable name that has not been used for any other
558
    variable in the function's scope is created. This is used to refer to the function's return value when evaluating
559
    postconditions.
560
    """
561
    legal_var_name = "__function_return_value__"
20✔
562

563
    while legal_var_name in var_dict:
20✔
564
        legal_var_name += "_"
×
565

566
    return legal_var_name
20✔
567

568

569
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """Return assertion with each FUNCTION_RETURN_VALUE placeholder replaced by return_val_var_name.

    Assertions that do not contain the placeholder are returned unchanged.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion
    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
580

581

582
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the given assertions are still satisfied.

    Evaluates the compiled conditions cached on wrapped's underlying function. These are
    the preconditions when condition_type is "precondition" and the postconditions when it
    is "postcondition". Evaluation uses wrapped's globals plus function_locals; for
    postconditions, the return value is also bound to its generated variable name.

    A condition that raises an exception other than AssertionError while being evaluated
    is skipped with a debug warning.

    Raises:
        PyTAContractError: if a condition evaluates to a falsy value.
        AssertionError: if evaluating a condition raises AssertionError.
    """
    # Conditions are cached on the underlying function of a bound method
    target = wrapped.__func__ if hasattr(wrapped, "__self__") else wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: a condition that cannot be evaluated is skipped, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f" and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string}{return_val_string}"
                )
625

626

627
def _strip_inline_comment(assertion: str) -> str:
    """Return assertion with a trailing ``#`` comment, and the whitespace before it, removed.

    The tokenizer is used so that a ``#`` inside a string literal (e.g. ``s != '#'``) is
    not mistaken for a comment. If the text cannot be tokenized, this falls back to cutting
    at the first ``#``.
    """
    if "#" not in assertion:
        return assertion

    import io
    import tokenize

    try:
        for token in tokenize.generate_tokens(io.StringIO(assertion).readline):
            if token.type == tokenize.COMMENT:
                return assertion[: token.start[1]].rstrip()
    except (tokenize.TokenError, SyntaxError):
        return assertion.split("#")[0].rstrip()
    return assertion


def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    A trailing ``# ...`` comment is removed from each condition. A ``#`` inside a string
    literal is kept.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # Check if obj is an astroid node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__") or ""
    lines = [line.strip() for line in docstring.split("\n")]
    assertion_lines = [
        i for i, line in enumerate(lines) if line.lower().startswith(parse_token.lower())
    ]

    if not assertion_lines:
        return []

    first = assertion_lines[0]

    # NOTE(review): the header line is located case-insensitively above, but the prefixes
    # below are matched case-sensitively. So e.g. "preconditions:" yields no assertions.
    # Confirm whether that is intended.
    if lines[first].startswith(parse_token + ":"):
        return [_strip_inline_comment(lines[first][len(parse_token + ":") :].strip())]
    elif lines[first].startswith(parse_token + "s:"):
        assertions = []
        for line in lines[first + 1 :]:
            if line.startswith("-"):
                assertion = _strip_inline_comment(line[1:].strip())
                if hasattr(obj, "__qualname__"):
                    _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                assertions.append(assertion)
            elif line != "":
                # Any other non-blank text ends the group of conditions
                break
        return assertions
    else:
        return []
671

672

673
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return repr(value), shortened for display in messages.

    Unless DEBUG_CONTRACTS is True, a repr longer than max_length characters is cut down to
    its first and last (max_length - 3) // 2 characters joined by "...".

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text
    keep = (max_length - 3) // 2
    return text[:keep] + "..." + text[-keep:]
687

688

689
def _display_annotation(annotation: Any) -> str:
20✔
690
    """Return a human-friendly representation of the given type annotation.
691

692
    >>> _display_annotation(int)
693
    'int'
694
    >>> _display_annotation(list[int])
695
    'list[int]'
696
    >>> from typing import List
697
    >>> _display_annotation(List[int])
698
    'typing.List[int]'
699
    """
700
    if annotation is type(None):  # Use 'None' instead of 'NoneType'
20✔
UNCOV
701
        return "None"
×
702
    if hasattr(annotation, "__origin__"):  # Generic type annotations
20✔
703
        return repr(annotation)
20✔
704
    elif hasattr(annotation, "__name__"):
20✔
705
        return annotation.__name__
20✔
706
    else:
UNCOV
707
        return repr(annotation)
×
708

709

710
def _get_module(obj: Any) -> ModuleType:
    """Return the module where obj was defined (normally sys.modules[obj.__module__]).

    Special case for PyCharm's "Run in Python Console": the pydevd runner renames the
    '__main__' module to 'pydev_umd'. If obj claims to come from '__main__', the rename
    workaround is enabled, and 'pydev_umd' is loaded, then a function or class that is not
    found in '__main__' is looked up in the 'pydev_umd' module instead. Set
    RENAME_MAIN_TO_PYDEV_UMD to False to disable this behaviour.
    """
    module_name = obj.__module__
    module = sys.modules[module_name]

    in_pydev_console = (
        module_name == "__main__"
        and RENAME_MAIN_TO_PYDEV_UMD
        and _PYDEV_UMD_NAME in sys.modules
    )
    # Other kinds of objects have no reliable name to look up, so be conservative
    if not in_pydev_console or not isinstance(obj, (FunctionType, type)):
        return module

    if obj.__name__ in vars(module):
        return module
    return sys.modules[_PYDEV_UMD_NAME]
740

741

742
def _debug(msg: str) -> None:
    """Log msg as a DEBUG message, but only when DEBUG_CONTRACTS is True."""
    if DEBUG_CONTRACTS:
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
751

752

753
def _set_invariants(klass: type) -> None:
    """Retrieve and set the representation invariants of this class.

    Invariants are collected from klass and its superclasses in reverse MRO order, so the
    most general class comes first. They are stored as a list of (assertion, compiled)
    pairs in klass.__representation_invariants__.

    A class that already has __representation_invariants__ contributes its stored list.
    Other non-builtin classes have their docstrings parsed. Invariants that fail to compile
    are skipped with a debug warning.
    """
    rep_invariants: list[tuple[str, CodeType]] = []

    # Iterate over all inherited classes except builtins
    for cls in reversed(klass.__mro__):
        if "__representation_invariants__" in cls.__dict__:
            rep_invariants.extend(cls.__representation_invariants__)
        elif cls.__module__ != "builtins":
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
            # Try compiling assertions
            for assertion in assertions:
                try:
                    compiled = compile(assertion, "<string>", "eval")
                except Exception:
                    # Best effort: skip invariants that are not valid expressions, without
                    # swallowing KeyboardInterrupt/SystemExit.
                    _debug(
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
                    )
                    continue
                rep_invariants.append((assertion, compiled))

    setattr(klass, "__representation_invariants__", rep_invariants)
776

777

778
def validate_invariants(obj: object) -> None:
    """Check that the representation invariants of obj are satisfied.

    Raises:
        AssertionError: if a representation invariant of obj's class is violated.
    """
    klass = obj.__class__
    global_scope = vars(_get_module(klass))
    try:
        _check_invariants(obj, klass, global_scope)
    except PyTAContractError as e:
        raise AssertionError(str(e)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc