• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 20014220265

08 Dec 2025 01:57AM UTC coverage: 93.843% (-0.1%) from 93.944%
20014220265

Pull #1268

github

web-flow
Merge b5949d480 into 491cb20a5
Pull Request #1268: Updated to pylint and astroid v4.0 and added support for Python 3.14

17 of 17 new or added lines in 5 files covered. (100.0%)

4 existing lines in 2 files now uncovered.

3536 of 3768 relevant lines covered (93.84%)

17.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.85
/python_ta/contracts/__init__.py
"""This module provides the functionality for PythonTA contracts.

Representation invariants, preconditions, and postconditions are parsed, compiled, and stored
as follows:
    - Representation invariants are stored in a class attribute __representation_invariants__
      as a list [(assertion, compiled)].
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
      [(assertion, compiled)].
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
      [(assertion, compiled, return_val_var_name)].
"""

from __future__ import annotations

import inspect
import logging
import re
import sys
import types
import typing
from types import CodeType, FunctionType, ModuleType
from typing import (
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
    get_args,
    get_origin,
    overload,
)

import wrapt
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
34

35
# Configuration options (read by the functions in this module each time they run)

ENABLE_CONTRACT_CHECKING = True
"""Set to True to enable contract checking."""

DEBUG_CONTRACTS = False
"""Set to True to display debugging messages when checking contracts."""

RENAME_MAIN_TO_PYDEV_UMD = True
"""Set to False to disable workaround for PyCharm's "Run File in Python Console" action.

In most cases you should not need to change this!
"""

STRICT_NUMERIC_TYPES = True
"""Set to False to allow more specific numeric types to be accepted by more general type annotations."""

# Module name used in place of __main__ by PyCharm's "Run in Python Console" action.
_PYDEV_UMD_NAME = "pydev_umd"

# Default maximum length used by _display_value when showing values in error messages.
_DEFAULT_MAX_VALUE_LENGTH = 30

# Placeholder that postconditions use to refer to the function's return value.
FUNCTION_RETURN_VALUE = "$return_value"
63

64

65
class PyTAContractError(Exception):
    """Error raised when a PyTA contract assertion is violated.

    Used internally for argument/return type mismatches, failed preconditions and
    postconditions, and violated representation invariants. The contract-checking wrappers
    convert it into an ``AssertionError`` before it reaches user code.
    """
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Automatically check contracts for all functions and classes in the given modules.

    By default (when called with no arguments), the current module is used.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported; names that are not found in ``sys.modules`` are ignored.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names = mod_names + ("__main__",)

        # Also add _PYDEV_UMD_NAME, handling when the file is being run in PyCharm
        # with the "Run in Python Console" action.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names = mod_names + (_PYDEV_UMD_NAME,)

    # Only functions/classes defined in one of these modules get contracts (built once).
    allowed_module_names = set(mod_names)

    # The same module object can be reached through several names (e.g. "__main__" passed
    # explicitly while decorate_main is True). Decorate each module only once so that its
    # functions are not wrapped in contract checks twice.
    decorated_module_ids = set()

    for module_name in mod_names:
        module = sys.modules.get(module_name, None)
        if not module or id(module) in decorated_module_ids:
            # Module name was passed in incorrectly, or this module was already decorated.
            continue
        decorated_module_ids.add(id(module))
        for name, value in inspect.getmembers(module):
            if inspect.isfunction(value) or inspect.isclass(value):
                module.__dict__[name] = check_contracts(
                    value, module_names=allowed_module_names
                )
102

103

104
# Wildcard Type Variable
Class = TypeVar("Class", bound=type)


# The overloads mirror the implementation's signature: the target is the first (only)
# positional parameter and all options are keyword-only.
@overload
def check_contracts(
    func_or_class: FunctionType,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> FunctionType: ...


@overload
def check_contracts(
    func_or_class: Class,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Class: ...


@overload
def check_contracts(
    func_or_class: None = None,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Callable[[Any], Any]: ...


def check_contracts(
    func_or_class: Optional[Union[Class, FunctionType]] = None,
    *,
    module_names: Optional[set[str]] = None,
    argument_types: bool = True,
    return_type: bool = True,
    preconditions: bool = True,
    postconditions: bool = True,
) -> Union[Class, FunctionType]:
    """A decorator to enable contract checking for a function or class.

    When used with a class, all methods defined within the class have contract checking enabled.
    If module_names is not None, only functions or classes defined in a module whose name is in module_names are checked.

    When used with functions, `check_contracts` accepts four optional boolean keyword arguments to selectively disable checks when set to `False`:

    - `argument_types`: check parameter type annotations
    - `return_type`: check the return type annotation
    - `preconditions`: check preconditions
    - `postconditions`: check postconditions

    By default, all four checks are enabled. These arguments only affect functions, and are ignored when `check_contracts` is applied to a class.

    When called with only keyword arguments (e.g. ``@check_contracts(preconditions=False)``),
    a decorator applying those options is returned.

    Example:
        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """

    @wrapt.decorator
    def _enable_function_contracts(wrapped, instance, args, kwargs):
        """A decorator that enables checking contracts for a function."""
        # For class methods wrapt passes the class itself as ``instance``; there is no
        # instance object in that case, so None is passed on instead.
        bound_instance = None if inspect.isclass(instance) else instance
        try:
            return _check_function_contracts(
                wrapped,
                bound_instance,
                args,
                kwargs,
                argument_types_enabled=argument_types,
                return_type_enabled=return_type,
                preconditions_enabled=preconditions,
                postconditions_enabled=postconditions,
            )
        except PyTAContractError as e:
            raise AssertionError(str(e)) from None

    # Optional Arguments passed to the decorator
    if func_or_class is None:
        return wrapt.PartialCallableObjectProxy(
            check_contracts,
            module_names=module_names,
            argument_types=argument_types,
            return_type=return_type,
            preconditions=preconditions,
            postconditions=postconditions,
        )

    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class
    elif inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)
    elif inspect.isclass(func_or_class):
        add_class_invariants(func_or_class)
        return func_or_class
    else:
        # Default action
        return func_or_class
222

223

224
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    Static and class methods defined in the class are wrapped with ``check_contracts``;
    other methods are wrapped with ``_instance_method_wrapper``. The class's ``__setattr__``
    is replaced so that assignments to annotated attributes are type-checked, and assignments
    made outside the class's own methods are checked against the representation invariants
    (and rolled back if they violate them).

    Does nothing if contract checking is disabled or the class was already decorated.
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # This means the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        Check representation invariants for this class when not within an instance method of the class.
        """
        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        nonlocal cls_annotations
        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

        if name in cls_annotations:
            try:
                # Report the attribute actually being assigned (``name``).
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} for attribute {name} did not match expected type "
                    f"{_display_annotation(cls_annotations[name])}"
                ) from None
        # Remember the previous value so the assignment can be undone if it breaks an invariant.
        original_attr_value_exists = False
        original_attr_value = None
        if hasattr(self, name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)
        super(klass, self).__setattr__(name, value)
        frame_locals = inspect.currentframe().f_back.f_locals
        caller_self = frame_locals.get("self")
        if not isinstance(caller_self, type(self)):
            # Only validating if the attribute is not being set in a instance/class method
            # AND caller_self is an instance of self's type
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None
        elif caller_self is not self:
            # Keep track of mutations to instances that are of the same type as caller_self (and are also not `self`)
            # to enforce RIs on them only after the caller function returns.
            caller_klass = type(caller_self)
            if hasattr(caller_klass, "__mutated_instances__"):
                mutated_instances = getattr(caller_klass, "__mutated_instances__")
                if self not in mutated_instances:
                    mutated_instances.append(self)

    # Iterate over a snapshot because setattr below replaces entries of klass.__dict__.
    for attr_name, attr_value in list(klass.__dict__.items()):
        # Skip built-in __annotate_func__, which was introduced in Python 3.14
        if attr_name == "__annotate_func__":
            continue
        if inspect.isroutine(attr_value):
            if isinstance(attr_value, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr_name, check_contracts(attr_value))
            else:
                setattr(klass, attr_name, _instance_method_wrapper(attr_value, klass))

    klass.__setattr__ = new_setattr
302

303

304
def _check_function_contracts(
    wrapped,
    instance,
    args,
    kwargs,
    argument_types_enabled: bool = True,
    return_type_enabled: bool = True,
    preconditions_enabled: bool = True,
    postconditions_enabled: bool = True,
):
    """Call ``wrapped(*args, **kwargs)`` and check its contracts before and after the call.

    Each of the following checks runs only when its ``*_enabled`` flag is True:
      - argument types against the parameter annotations;
      - preconditions parsed from the docstring;
      - the return value against the return annotation;
      - postconditions parsed from the docstring.
    Conditions are additionally skipped when ENABLE_CONTRACT_CHECKING is False. They are
    compiled on first use and cached on the underlying function as ``__preconditions__`` /
    ``__postconditions__``.

    Arguments are matched to parameters whether they were passed positionally or by keyword.
    The following are neither type-checked nor visible to the conditions:
      - parameters left at their default values;
      - arguments collected by ``*args``/``**kwargs``.

    Args:
        wrapped: the function (or bound method) being called.
        instance: the object an instance method is bound to, or None for functions, static
            methods and class methods.
        args: the positional arguments of the call, excluding the implicit self/cls.
        kwargs: the keyword arguments of the call.

    Returns:
        The value returned by ``wrapped``.

    Raises:
        PyTAContractError: if an argument type, precondition, return type or postcondition
            check fails.
    """
    if instance is not None:
        klass_mod = _get_module(type(instance))
        annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
    else:
        annotations = typing.get_type_hints(wrapped)

    # The implicit first argument (self/cls) is not part of ``args``. For instance methods it
    # is ``instance``. Class methods arrive with ``instance`` set to None and ``wrapped`` bound
    # to the class, so the class is recovered from ``wrapped.__self__``.
    if instance is not None:
        implicit_args = (instance,)
    elif inspect.ismethod(wrapped):
        implicit_args = (wrapped.__self__,)
    else:
        implicit_args = ()

    # Parameter name -> value for every argument supplied positionally or by keyword.
    function_locals = _bind_supplied_arguments(wrapped, implicit_args + tuple(args), kwargs)

    if argument_types_enabled:
        # Check function parameter types
        for param, arg in function_locals.items():
            if param not in annotations:
                continue
            try:
                _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
                if STRICT_NUMERIC_TYPES:
                    check_type_strict(param, arg, annotations[param])
                else:
                    check_type(arg, annotations[param])
            except (TypeError, TypeCheckError):
                additional_suggestions = _get_argument_suggestions(arg, annotations[param])

                raise PyTAContractError(
                    f"Argument value {_display_value(arg)} for {wrapped.__name__} parameter {param} "
                    f"did not match expected type {_display_annotation(annotations[param])}"
                    + (f"\n{additional_suggestions}" if additional_suggestions else "")
                )

    # Conditions are cached on the underlying function rather than on a bound method object.
    target = wrapped.__func__ if hasattr(wrapped, "__self__") else wrapped

    # Check function preconditions
    if preconditions_enabled and not hasattr(target, "__preconditions__"):
        target.__preconditions__: list[tuple[str, CodeType]] = []
        preconditions = parse_assertions(wrapped)
        for precondition in preconditions:
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING and preconditions_enabled:
        _check_assertions(wrapped, function_locals)

    # Check return type
    r = wrapped(*args, **kwargs)
    if return_type_enabled and "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"Return value {_display_value(r)} for {wrapped.__name__} did not match "
                f"expected type {_display_annotation(return_type)}"
            )

    # Check function postconditions
    if postconditions_enabled and not hasattr(target, "__postconditions__"):
        target.__postconditions__: list[tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        postconditions = parse_assertions(wrapped, parse_token="Postcondition")
        for postcondition in postconditions:
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING and postconditions_enabled:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r


def _bind_supplied_arguments(wrapped, positional_args: tuple, kwargs: dict) -> dict[str, Any]:
    """Map the parameter names of ``wrapped`` to the argument values explicitly supplied.

    Positional arguments are matched to the positional parameters in order.
    Keyword arguments are matched to the positional-or-keyword and keyword-only parameters.

    The following are not included in the result:
      - parameters that take their default values;
      - values collected by ``*args``/``**kwargs`` parameters.
    """
    code = wrapped.__code__
    param_names = code.co_varnames[: code.co_argcount + code.co_kwonlyargcount]
    bound = dict(zip(param_names[: code.co_argcount], positional_args))
    # Positional-only parameters cannot be passed by keyword, so they are skipped here.
    for name in param_names[code.co_posonlyargcount :]:
        if name in kwargs and name not in bound:
            bound[name] = kwargs[name]
    return bound
409

410

411
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Check ``value`` against ``expected_type`` without implicit numeric widening.

    Within the numeric hierarchy (bool, int, float, complex) a narrower type is not accepted
    for a wider annotation. For example, bool is rejected for int/float/complex, int for
    float/complex, and float for complex. This also applies inside the supported container
    and Union annotations. Does nothing when contract checking is disabled.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``.
    """
    if ENABLE_CONTRACT_CHECKING:
        try:
            _check_inner_type(argname, value, expected_type)
        except (TypeError, TypeCheckError):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
423

424

425
def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
    """Recursively check that ``value`` matches ``expected_type`` for strict type validation.

    Supports collections such as ``list[int]``, ``set[int]``, ``dict[str, float]``,
    ``tuple[int, str]`` and ``tuple[int, ...]``. Also supports Union types written as
    ``Union[bool, int]``, ``Optional[int]`` or ``bool | int``. The strict numeric rules are
    applied to every element or member.

    The following are delegated to typeguard's ``check_type``:
      - any other annotation;
      - unparameterized generics such as ``typing.List``.

    Raises:
        TypeError: if ``value`` does not match ``expected_type`` under the strict rules.
        TypeCheckError: if typeguard's ``check_type`` rejects ``value``.
    """
    inner_types = get_args(expected_type)
    outer_type = get_origin(expected_type)
    if outer_type is None:
        if (
            (type(value) is bool and expected_type in {int, float, complex})
            or (type(value) is int and expected_type in {float, complex})
            or (type(value) is float and expected_type is complex)
        ):
            raise TypeError(
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
            )
        else:
            check_type(
                value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
    elif not inner_types:
        # Unparameterized aliases such as ``typing.List`` or ``typing.Tuple`` have an origin
        # but no type arguments, so there are no element types to check strictly.
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
    elif outer_type is typing.Union or outer_type is getattr(types, "UnionType", None):
        # ``X | Y`` unions have origin ``types.UnionType``. On Python 3.10-3.13 this is a
        # different object from ``typing.Union``, so both are accepted here.
        for inner_type in inner_types:
            try:
                _check_inner_type(argname, value, inner_type)
                return
            except (TypeError, TypeCheckError):
                pass
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type in {list, set}:
        if isinstance(value, outer_type):
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is dict:
        if isinstance(value, dict):
            for key, item in value.items():
                _check_inner_type(argname, key, inner_types[0])
                _check_inner_type(argname, item, inner_types[1])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is tuple:
        if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
            # Variable-length homogeneous tuple, e.g. tuple[int, ...]
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        elif isinstance(value, tuple) and len(value) == len(inner_types):
            # Fixed-length tuple: check each position against its own type
            for item, inner_type in zip(value, inner_types):
                _check_inner_type(argname, item, inner_type)
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    else:
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
478

479

480
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
    """Return a hint for a mismatched argument, or the empty string if there is none.

    When ``arg`` is itself a class that is a subclass of ``annotation``, the caller probably
    passed the class where an instance was expected, so the hint suggests calling it.
    """
    if not isinstance(arg, type):
        return ""
    try:
        arg_is_subclass = issubclass(arg, annotation)
    except TypeError:
        # ``annotation`` is not usable with issubclass (e.g. a parameterized generic).
        return ""
    if arg_is_subclass:
        return "Did you mean {cls}(...) instead of {cls}?".format(cls=arg.__name__)
    return ""
489

490

491
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Return instance method ``wrapped`` of ``klass`` wrapped with contract checking.

    Each call checks the method's own contracts. Afterwards, unless the instance's
    ``__init__`` is still running, it also checks:
      - ``klass``'s attribute type annotations for the instance;
      - ``klass``'s representation invariants for the instance;
      - the representation invariants of every other instance recorded in
        ``__mutated_instances__`` during the call.
    Contract violations are re-raised as ``AssertionError``.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        # Create an accumulator to store the instances mutated across this function call.
        # Store and restore existing mutated instance lists in case the instance method
        # executes another instance method.
        instance_klass = type(instance)
        # Look only at the class's *own* attribute. hasattr/getattr would also find a list
        # inherited from a base class whose method is currently running. Restoring that
        # inherited list with setattr would leave a stale copy permanently on this subclass.
        mutated_instances_to_restore = vars(instance_klass).get("__mutated_instances__")
        setattr(instance_klass, "__mutated_instances__", [])

        try:
            r = _check_function_contracts(wrapped, instance, args, kwargs)
            if _instance_init_in_callstack(instance):
                return r
            _check_class_type_annotations(klass, instance)
            klass_mod = _get_module(klass)
            if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
                _check_invariants(instance, klass, klass_mod.__dict__)

                # Additionally check RI violations on PyTA-decorated instances that were mutated
                # across the function call.
                mutated_instances = getattr(instance_klass, "__mutated_instances__", [])
                for mutated_instance in mutated_instances:
                    # Mutated instances may be of parent class types so the invariants to check should also be
                    # for the parent class and not the child class.
                    mutated_instance_klass = type(mutated_instance)
                    mutated_instance_klass_mod = _get_module(mutated_instance_klass)
                    _check_invariants(
                        mutated_instance,
                        mutated_instance_klass,
                        mutated_instance_klass_mod.__dict__,
                    )
        except PyTAContractError as e:
            raise AssertionError(str(e)) from None
        else:
            return r
        finally:
            if mutated_instances_to_restore is None:
                delattr(instance_klass, "__mutated_instances__")
            else:
                setattr(instance_klass, "__mutated_instances__", mutated_instances_to_restore)

    return wrapper(wrapped)
536

537

538
def _instance_init_in_callstack(instance: Any) -> bool:
    """Return whether instance's init is part of the current callstack

    Walks the caller frames looking for a frame of a function named ``__init__`` whose first
    local variable is ``self`` and whose ``self`` is ``instance``.

    Note: due to the nature of the check, externally defined __init__ functions with
    'self' defined as the first parameter may pass this check.
    """
    frame = inspect.currentframe().f_back
    while frame:
        code = frame.f_code
        # ``code.co_name`` is the value inspect.getframeinfo(frame).function returns, without
        # getframeinfo's extra work of looking up source lines for every frame walked.
        # ``co_varnames[:1]`` avoids an IndexError when ``self`` is only a free variable.
        if (
            code.co_name == "__init__"
            and frame.f_locals.get("self") is instance
            and code.co_varnames[:1] == ("self",)
        ):
            return True
        frame = frame.f_back
    return False
557

558

559
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Every attribute annotated on ``klass`` is read from ``instance`` and checked with
    typeguard, including every item of collections.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an annotated attribute's value does not match its annotation.
    """
    klass_mod = _get_module(klass)
    cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

    for attr, annotation in cls_annotations.items():
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress the typeguard traceback, as the equivalent check in new_setattr does.
            raise AssertionError(
                f"Value {_display_value(value)} for attribute {attr} did not match expected type "
                f"{_display_annotation(annotation)}"
            ) from None
580

581

582
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each compiled invariant in ``klass.__representation_invariants__`` is evaluated with
    ``global_scope`` as globals and ``self`` bound to ``instance``. An invariant that cannot
    be evaluated is skipped:
      - if the failure is a NameError for a name that is an attribute of ``instance``, a
        warning suggesting ``self.<name>`` is printed to stderr;
      - otherwise, a debug message is emitted.

    Raises:
        PyTAContractError: if an invariant evaluates to a falsy value.
        AssertionError: if evaluating an invariant raises AssertionError.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    # Set the guard flag through the next class in the MRO, bypassing type(instance)'s own
    # __setattr__ (which may be an invariant-checking replacement).
    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except NameError as e:
                # Get the missing name
                missing = getattr(e, "name", None)
                if missing is None:
                    # Failsafe for version 3.9
                    message = re.search(r"name '(.+?)' is not defined", str(e))
                    if message:
                        missing = message.group(1)

                # Check if missing name is an attribute
                if missing is not None and hasattr(instance, missing):
                    print(
                        f"[WARNING] Could not find variable `{missing}` when evaluating representation invariant. Did you mean `self.{missing}`?",
                        file=sys.stderr,
                    )
                else:
                    _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            except Exception:
                # Best effort: an invariant that fails to evaluate is skipped, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'{instance.__class__.__name__} representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        delattr(instance, "__pyta_currently_checking")
638

639

640
def _get_legal_return_val_var_name(var_dict: dict) -> str:
    """Return a fresh variable name for referring to a function's return value.

    The result is "__function_return_value__" followed by the fewest trailing underscores
    needed for it not to be a key of ``var_dict`` (the names already in scope for the
    function). It is used to bind the return value when evaluating postconditions.
    """
    base_name = "__function_return_value__"
    extra_underscores = 0
    while base_name + "_" * extra_underscores in var_dict:
        extra_underscores += 1
    return base_name + "_" * extra_underscores
652

653

654
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """Return ``assertion`` with every FUNCTION_RETURN_VALUE placeholder replaced.

    Each occurrence of FUNCTION_RETURN_VALUE is replaced by ``return_val_var_name``.
    An assertion that does not mention FUNCTION_RETURN_VALUE is returned unchanged; in that
    case ``return_val_var_name`` is not used.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion
    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
665

666

667
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the cached preconditions or postconditions of ``wrapped`` are satisfied.

    ``condition_type`` selects whether ``__preconditions__`` or ``__postconditions__`` is read
    from the underlying function. Each condition is evaluated in a namespace built from:
      - the function's globals;
      - ``function_locals``;
      - for postconditions, ``function_return_val`` bound to the stored return-value name.
    Conditions whose evaluation raises (other than AssertionError) are skipped with a debug
    message.

    Raises:
        PyTAContractError: if a condition evaluates to a falsy value.
        AssertionError: if evaluating a condition raises AssertionError.
    """
    # Check bounded function
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: skip conditions that cannot be evaluated, without swallowing
            # KeyboardInterrupt/SystemExit.
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f" and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string}{return_val_string}"
                )
710

711

712
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.
    obj may be a runtime object (its __doc__ is read) or an astroid node with a
    non-None doc_node (its doc_node.value is read).

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    Both headers are matched case-insensitively (e.g. "Representation invariants:"
    is accepted for parse_token "Representation Invariant"), and the first line
    matching either header is used. Lines that merely begin with the token, such
    as "Preconditions are checked elsewhere", are not treated as headers.
    Returns an empty list when no header is found.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # obj is an astroid node: read the docstring from its doc_node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__", None) or ""
    lines = [line.strip() for line in docstring.split("\n")]

    # Header prefixes, lowercased so that matching is case-insensitive.
    single_header = (parse_token + ":").lower()
    group_header = (parse_token + "s:").lower()

    for index, line in enumerate(lines):
        lowered = line.lower()
        if lowered.startswith(single_header):
            # Form 1: the condition follows the header on the same line.
            return [line[len(single_header) :].strip()]
        if lowered.startswith(group_header):
            # Form 2: collect "- <cond>" lines (blank lines allowed) until other text.
            assertions = []
            for entry in lines[index + 1 :]:
                if entry.startswith("-"):
                    assertion = entry[1:].strip()
                    if hasattr(obj, "__qualname__"):
                        _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                    assertions.append(assertion)
                elif entry != "":
                    break
            return assertions
    return []
754

755

756
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return a human-friendly representation of the given value.

    The representation is repr(value). If DEBUG_CONTRACTS is False and that
    representation is longer than max_length characters, it is shortened to its
    first and last (max_length - 3) // 2 characters joined by "...".

    Preconditions:
        - max_length >= 5
    """
    s = repr(value)
    if DEBUG_CONTRACTS or len(s) <= max_length:
        return s
    # Clamp at 0 so an out-of-contract max_length cannot produce negative slices.
    i = max((max_length - 3) // 2, 0)
    # Use an explicit start index for the tail: s[-i:] would return the whole
    # string when i == 0, making the "truncated" value longer than the original.
    return s[:i] + "..." + s[len(s) - i :]
770

771

772
def _display_annotation(annotation: Any) -> str:
20✔
773
    """Return a human-friendly representation of the given type annotation.
774

775
    >>> _display_annotation(int)
776
    'int'
777
    >>> _display_annotation(list[int])
778
    'list[int]'
779
    >>> from typing import List
780
    >>> _display_annotation(List[int])
781
    'typing.List[int]'
782
    """
783
    if annotation is type(None):  # Use 'None' instead of 'NoneType'
20✔
784
        return "None"
×
785
    if hasattr(annotation, "__origin__"):  # Generic type annotations
20✔
786
        return repr(annotation)
20✔
787
    elif hasattr(annotation, "__name__"):
20✔
788
        return annotation.__name__
20✔
789
    else:
790
        return repr(annotation)
×
791

792

793
def _get_module(obj: Any) -> ModuleType:
20✔
794
    """Return the module where obj was defined (normally obj.__module__).
795

796
    NOTE: this function defines a special case when using PyCharm and the file
797
    defining the object is "Run in Python Console". In this case, the pydevd runner
798
    renames the '__main__' module to 'pydev_umd', and so we need to access that
799
    module instead. This behaviour can be disabled by setting RENAME_MAIN_TO_PYDEV_UMD
800
    to False.
801
    """
802
    module_name = obj.__module__
20✔
803
    module = sys.modules[module_name]
20✔
804

805
    if (
20✔
806
        module_name != "__main__"
807
        or not RENAME_MAIN_TO_PYDEV_UMD
808
        or _PYDEV_UMD_NAME not in sys.modules
809
    ):
810
        return module
20✔
811

812
    # Get a function/class name to check whether it is defined in the module
813
    if isinstance(obj, (FunctionType, type)):
×
814
        name = obj.__name__
×
815
    else:
816
        # For any other type of object, be conservative and just return the module
817
        return module
×
818

819
    if name in vars(module):
×
820
        return module
×
821
    else:
822
        return sys.modules[_PYDEV_UMD_NAME]
×
823

824

825
def _debug(msg: str) -> None:
    """Log msg at DEBUG level on the root logger.

    This does nothing unless DEBUG_CONTRACTS is True.
    """
    if DEBUG_CONTRACTS:
        # Per the logging docs, basicConfig does nothing once the root logger
        # already has handlers, so calling it on every message is harmless.
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
834

835

836
def _set_invariants(klass: type) -> None:
20✔
837
    """Retrieve and set the representation invariants of this class"""
838
    # Update representation invariants from this class' docstring and those of its superclasses.
839
    rep_invariants: list[tuple[str, CodeType]] = []
20✔
840

841
    # Iterate over all inherited classes except builtins
842
    for cls in reversed(klass.__mro__):
20✔
843
        if "__representation_invariants__" in cls.__dict__:
20✔
844
            rep_invariants.extend(cls.__representation_invariants__)
20✔
845
        elif cls.__module__ != "builtins":
20✔
846
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
20✔
847
            # Try compiling assertions
848
            for assertion in assertions:
20✔
849
                try:
20✔
850
                    compiled = compile(assertion, "<string>", "eval")
20✔
851
                except:
×
852
                    _debug(
×
853
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
854
                    )
855
                    continue
×
856
                rep_invariants.append((assertion, compiled))
20✔
857

858
    setattr(klass, "__representation_invariants__", rep_invariants)
20✔
859

860

861
def validate_invariants(obj: object) -> None:
    """Check that the representation invariants of obj are satisfied.

    The invariants of obj's class are evaluated by _check_invariants, using the
    global namespace of the module that defines the class (as resolved by
    _get_module).

    Raises:
        AssertionError: if an invariant is violated. The PyTAContractError
            reported by _check_invariants is re-raised as an AssertionError
            with the same message and with exception chaining suppressed.
    """
    cls_of_obj = obj.__class__
    defining_module = _get_module(cls_of_obj)

    try:
        _check_invariants(obj, cls_of_obj, defining_module.__dict__)
    except PyTAContractError as violation:
        raise AssertionError(str(violation)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc