• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 11657005865

04 Nov 2024 02:25AM UTC coverage: 91.955% (+0.06%) from 91.895%
11657005865

Pull #1103

github

web-flow
Merge e7780cda3 into 5d50aa36b
Pull Request #1103: Enable Strict Type Checking for Complex Data Types in check_contracts

40 of 40 new or added lines in 1 file covered. (100.0%)

1 existing line in 1 file now uncovered.

3086 of 3356 relevant lines covered (91.95%)

9.18 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.91
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.
2

3
Representation invariants, preconditions, and postconditions are parsed, compiled, and stored.
4
Below are some notes on how they are stored.
5
    - Representation invariants are stored in a class attribute __representation_invariants__
6
    as a list [(assertion, compiled)].
7
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
8
    [(assertion, compiled)].
9
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
10
    [(assertion, compiled, return_val_var_name)].
11
"""
12

13
from __future__ import annotations
10✔
14

15
import inspect
10✔
16
import logging
10✔
17
import sys
10✔
18
import typing
10✔
19
from types import CodeType, FunctionType, ModuleType
10✔
20
from typing import (
10✔
21
    Any,
22
    Callable,
23
    Collection,
24
    Optional,
25
    TypeVar,
26
    Union,
27
    get_args,
28
    get_origin,
29
    overload,
30
)
31

32
import wrapt
10✔
33
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
10✔
34

35
# Configuration options
36

37
ENABLE_CONTRACT_CHECKING = True
10✔
38
"""
8✔
39
Set to True to enable contract checking.
40
"""
41

42
DEBUG_CONTRACTS = False
10✔
43
"""
8✔
44
Set to True to display debugging messages when checking contracts.
45
"""
46

47
RENAME_MAIN_TO_PYDEV_UMD = True
10✔
48
"""
8✔
49
Set to False to disable workaround for PyCharm's "Run File in Python Console" action.
50
In most cases you should not need to change this!
51
"""
52

53
STRICT_NUMERIC_TYPES = True
10✔
54
"""
8✔
55
Set to False to allow more specific numeric types to be accepted by more general type annotations.
56
"""
57

58
_PYDEV_UMD_NAME = "pydev_umd"
10✔
59

60

61
_DEFAULT_MAX_VALUE_LENGTH = 30
10✔
62
FUNCTION_RETURN_VALUE = "$return_value"
10✔
63

64

65
class PyTAContractError(Exception):
    """Signal that a PythonTA contract (type, pre/postcondition, invariant) failed.

    The internal checking helpers raise this error; the public wrappers in this
    module catch it and re-raise its message as an ``AssertionError``.
    """
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Enable contract checking for every function and class in the named modules.

    Each module is looked up in ``sys.modules``; names that are not found there
    are silently skipped. Every function or class member of a found module is
    replaced, in the module's namespace, by the result of ``check_contracts``.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported.
        decorate_main: When True, ``"__main__"`` is added to the module names (and, if
            RENAME_MAIN_TO_PYDEV_UMD is set, the PyCharm console module name as well).
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names += ("__main__",)

        # PyCharm's "Run in Python Console" action renames __main__, so include
        # that module name as well.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names += (_PYDEV_UMD_NAME,)

    # Resolve every module up front, before any member is rewritten.
    resolved = [sys.modules.get(mod_name, None) for mod_name in mod_names]

    for module in resolved:
        if not module:
            # Module name was passed in incorrectly.
            continue
        for member_name, member in inspect.getmembers(module):
            if not (inspect.isfunction(member) or inspect.isclass(member)):
                continue
            module.__dict__[member_name] = check_contracts(member, module_names=set(mod_names))
102

103

104
@wrapt.decorator
def _enable_function_contracts(wrapped, instance, args, kwargs):
    """Run ``wrapped`` with contract checking, reporting violations as AssertionError."""
    # For a class method, `instance` is the class itself, so there is no
    # instance to pass along.
    receiver = None if inspect.isclass(instance) else instance
    try:
        return _check_function_contracts(wrapped, receiver, args, kwargs)
    except PyTAContractError as violation:
        raise AssertionError(str(violation)) from None
115

116

117
# Wildcard Type Variable
118
Class = TypeVar("Class", bound=type)
10✔
119

120

121
@overload
10✔
122
def check_contracts(
10✔
123
    func: FunctionType, module_names: Optional[set[str]] = None
124
) -> FunctionType: ...
125

126

127
@overload
10✔
128
def check_contracts(func: Class, module_names: Optional[set[str]] = None) -> Class: ...
10✔
129

130

131
def check_contracts(
    func_or_class: Union[Class, FunctionType], module_names: Optional[set[str]] = None
) -> Union[Class, FunctionType]:
    """Decorate a function or class so that its contracts are checked at run time.

    A routine is wrapped so each call is checked. A class is modified in place
    (its methods and ``__setattr__`` are instrumented) and then returned. Any other
    object, or any object when contract checking is disabled, is returned unchanged.

    If module_names is not None, objects whose ``__module__`` is not in module_names
    are returned unchanged.

    Example:

        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """
    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class

    if inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)

    if inspect.isclass(func_or_class):
        add_class_invariants(func_or_class)

    # Classes (now instrumented) and anything else fall through unchanged.
    return func_or_class
167

168

169
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    The class is changed in place:
      - its representation invariants are collected and stored (see _set_invariants);
      - every routine in the class body is wrapped with contract checking;
      - ``__setattr__`` is replaced so that attribute assignments are type checked
        against the class annotations and, when made from outside a method of the
        same instance, followed by a representation-invariant check.

    Does nothing when contract checking is disabled or when the class already has
    its own ``__representation_invariants__`` attribute.
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # This means the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        Check representation invariants for this class when not within an instance method of the class.
        If the invariants are violated, the assignment is rolled back before the
        AssertionError is raised.
        """
        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        nonlocal cls_annotations
        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

        if name in cls_annotations:
            try:
                # Report the attribute actually being assigned. (`attr` here would be
                # the leftover loop variable of the method-wrapping loop below.)
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} did not match type annotation for attribute "
                    f"{name}: {_display_annotation(cls_annotations[name])}"
                ) from None

        # Remember the previous value so the assignment can be undone if it
        # breaks a representation invariant.
        original_attr_value_exists = False
        original_attr_value = None
        if hasattr(super(klass, self), name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)
        super(klass, self).__setattr__(name, value)

        frame_locals = inspect.currentframe().f_back.f_locals
        if self is not frame_locals.get("self"):
            # Only validating if the attribute is not being set in a instance/class method
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None

    for attr, value in klass.__dict__.items():
        if inspect.isroutine(value):
            if isinstance(value, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr, check_contracts(value))
            else:
                setattr(klass, attr, _instance_method_wrapper(value, klass))

    klass.__setattr__ = new_setattr
234

235

236
def _check_function_contracts(wrapped, instance, args, kwargs):
    """Call ``wrapped`` while checking its type annotations, preconditions and postconditions.

    Positional arguments are paired with the function's parameter names and type
    checked against the annotations; keyword arguments are forwarded to the call
    but are not paired with parameters here. Pre- and postconditions are parsed
    from the docstring and compiled once, then cached on the underlying function
    as ``__preconditions__`` / ``__postconditions__``.

    Args:
        wrapped: The function (or bound method) being called.
        instance: The object the method is bound to, or None.
        args: Positional arguments of the call (excluding ``instance``).
        kwargs: Keyword arguments of the call.

    Returns:
        Whatever ``wrapped(*args, **kwargs)`` returns.

    Raises:
        PyTAContractError: If an argument or the return value does not match its
            annotation, or a precondition/postcondition evaluates to a falsy value.
    """
    params = wrapped.__code__.co_varnames[: wrapped.__code__.co_argcount]
    if instance is not None:
        klass_mod = _get_module(type(instance))
        annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
    else:
        annotations = typing.get_type_hints(wrapped)
    args_with_self = args if instance is None else (instance,) + args

    # Check function parameter types
    for arg, param in zip(args_with_self, params):
        if param in annotations:
            try:
                _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
                if STRICT_NUMERIC_TYPES:
                    check_type_strict(param, arg, annotations[param])
                else:
                    check_type(arg, annotations[param])
            except (TypeError, TypeCheckError):
                additional_suggestions = _get_argument_suggestions(arg, annotations[param])

                raise PyTAContractError(
                    f"{wrapped.__name__} argument {_display_value(arg)} did not match type "
                    f"annotation for parameter {param}: {_display_annotation(annotations[param])}"
                    + (f"\n{additional_suggestions}" if additional_suggestions else "")
                )

    function_locals = dict(zip(params, args_with_self))

    # Check bounded function: cached conditions live on the underlying function object
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped

    # Check function preconditions
    if not hasattr(target, "__preconditions__"):
        target.__preconditions__: list[tuple[str, CodeType]] = []
        preconditions = parse_assertions(wrapped)
        for precondition in preconditions:
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                # Best effort: an unparsable condition is skipped, but interpreter-exit
                # signals (KeyboardInterrupt, SystemExit) are no longer swallowed.
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(wrapped, function_locals)

    # Check return type
    r = wrapped(*args, **kwargs)
    if "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"{wrapped.__name__}'s return value {_display_value(r)} did not match "
                f"return type annotation {_display_annotation(return_type)}"
            )

    # Check function postconditions
    if not hasattr(target, "__postconditions__"):
        target.__postconditions__: list[tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        postconditions = parse_assertions(wrapped, parse_token="Postcondition")
        for postcondition in postconditions:
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r
331

332

333
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Check ``value`` against ``expected_type`` without numeric-tower leniency.

    Within the numeric hierarchy (bool, int, float, complex) a value of a more
    specific type is rejected where a more general type is expected. Does nothing
    when contract checking is disabled.

    Raises:
        TypeError: If the value does not match; ``argname`` is used in the message.
    """
    if ENABLE_CONTRACT_CHECKING:
        try:
            _check_inner_type(argname, value, expected_type)
        except (TypeError, TypeCheckError):
            # Normalise every mismatch to one top-level message.
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
345

346

347
def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
    """Recursively check that ``value`` matches ``expected_type`` using strict numeric rules.

    Strict means a bool is not accepted as int/float/complex, an int is not accepted
    as float/complex, and a float is not accepted as complex. The rule is applied
    to the elements of parameterised ``list``, ``set``, ``dict`` and ``tuple``
    annotations and to each member of a union, written either as ``Union[bool, int]``
    or as ``bool | int``. Everything else is delegated to typeguard's ``check_type``.

    Raises:
        TypeError: On a strict-numeric or container-shape mismatch.
        TypeCheckError: When the delegated typeguard check fails.
    """
    # Local import: the module only imports individual names from `types`.
    import types

    inner_types = get_args(expected_type)
    outer_type = get_origin(expected_type)
    if outer_type is None:
        if (
            (type(value) is bool and expected_type in {int, float, complex})
            or (type(value) is int and expected_type in {float, complex})
            or (type(value) is float and expected_type is complex)
        ):
            raise TypeError(
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
            )
        else:
            check_type(
                value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
    elif outer_type is typing.Union or outer_type is getattr(types, "UnionType", None):
        # `X | Y` annotations report types.UnionType as their origin (Python 3.10+);
        # they must get the same strict treatment as typing.Union.
        for inner_type in inner_types:
            try:
                _check_inner_type(argname, value, inner_type)
                return
            except (TypeError, TypeCheckError):
                pass
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif not inner_types:
        # Unparameterised alias such as typing.List or typing.Dict: there is no element
        # type to recurse into (indexing inner_types would raise IndexError), so let
        # typeguard check the container itself.
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
    elif outer_type in {list, set}:
        if isinstance(value, outer_type):
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is dict:
        if isinstance(value, dict):
            for key, item in value.items():
                _check_inner_type(argname, key, inner_types[0])
                _check_inner_type(argname, item, inner_types[1])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is tuple:
        if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
            # Homogeneous, variable-length tuple: tuple[T, ...]
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        elif isinstance(value, tuple) and len(value) == len(inner_types):
            # Fixed-length tuple: check each position against its own type
            for item, inner_type in zip(value, inner_types):
                _check_inner_type(argname, item, inner_type)
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    else:
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
400

401

402
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
    """Return a hint for a mismatched argument, or the empty string if there is none.

    The only hint produced is for passing a class where an instance was expected,
    i.e. when ``arg`` is itself a class that is a subclass of ``annotation``.
    """
    hint = ""
    if isinstance(arg, type):
        try:
            passed_class_itself = issubclass(arg, annotation)
        except TypeError:
            # The annotation cannot be used with issubclass (e.g. a parameterised generic).
            passed_class_itself = False
        if passed_class_itself:
            hint = "Did you mean {cls}(...) instead of {cls}?".format(cls=arg.__name__)
    return hint
411

412

413
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Wrap an instance method of ``klass`` with contract and invariant checking.

    After the method's own contracts pass, the instance's attribute annotations and
    the class's representation invariants are checked, unless the call happens
    while the instance's ``__init__`` is still on the call stack.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        try:
            result = _check_function_contracts(wrapped, instance, args, kwargs)
            if not _instance_init_in_callstack(instance):
                _check_class_type_annotations(klass, instance)
                module = _get_module(klass)
                if module is not None and ENABLE_CONTRACT_CHECKING:
                    _check_invariants(instance, klass, module.__dict__)
        except PyTAContractError as violation:
            raise AssertionError(str(violation)) from None
        return result

    return wrapper(wrapped)
430

431

432
def _instance_init_in_callstack(instance: Any) -> bool:
    """Return whether instance's init is part of the current callstack

    A frame counts as the instance's ``__init__`` when its function is named
    ``__init__``, its first local variable name is ``self``, and its ``self``
    local is this very instance.

    Note: due to the nature of the check, externally defined __init__ functions with
    'self' defined as the first parameter may pass this check.
    """
    current = inspect.currentframe()
    # currentframe() can return None on implementations without frame support.
    frame = current.f_back if current is not None else None
    while frame:
        code = frame.f_code
        # Read the function name straight from the code object; building a full
        # frame-info record for every frame only to read the name is unnecessary work.
        if (
            code.co_name == "__init__"
            and frame.f_locals.get("self") is instance
            # Slice rather than index so a function with no local names cannot raise.
            and code.co_varnames[:1] == ("self",)
        ):
            return True
        frame = frame.f_back
    return False
451

452

453
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Verify that every annotated attribute of ``instance`` still has a matching value.

    The annotations are those of ``klass``, resolved against the namespace of the
    module where the class was defined.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: If an attribute's current value does not match its annotation.
    """
    hints = typing.get_type_hints(klass, localns=_get_module(klass).__dict__)

    for attr_name in hints:
        expected = hints[attr_name]
        current = getattr(instance, attr_name)
        try:
            _debug(f"Checking type of attribute {attr_name} for {klass.__qualname__} instance")
            check_type(
                current, expected, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            raise AssertionError(
                f"{_display_value(current)} did not match type annotation for attribute {attr_name}: "
                f"{_display_annotation(expected)}"
            )
474

475

476
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each compiled invariant stored on ``klass`` is evaluated with ``global_scope``
    as its globals plus ``self`` bound to ``instance``. An invariant that cannot be
    evaluated is reported through the debug log and skipped.

    A temporary ``__pyta_currently_checking`` attribute marks the instance while the
    check runs, so that invariants which themselves trigger checks do not recurse.

    Raises:
        PyTAContractError: If an invariant evaluates to a falsy value.
        AssertionError: If evaluating an invariant itself raises AssertionError.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except Exception:
                # Best effort: skip invariants that fail to evaluate, but let
                # KeyboardInterrupt/SystemExit propagate.
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'"{instance.__class__.__name__}" representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        # Always clear the marker, even when a violation is raised.
        delattr(instance, "__pyta_currently_checking")
515

516

517
def _get_legal_return_val_var_name(var_dict: dict) -> str:
    """
    Return a variable name for the function's return value that is not a key of var_dict.

    The name is "__function_return_value__" followed by as many extra underscores as
    are needed to avoid every name already present in var_dict. It is used to refer
    to the return value when evaluating postconditions.
    """
    base_name = "__function_return_value__"
    extra_underscores = 0

    while base_name + "_" * extra_underscores in var_dict:
        extra_underscores += 1

    return base_name + "_" * extra_underscores
529

530

531
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """
    Return the assertion with every FUNCTION_RETURN_VALUE placeholder replaced by
    return_val_var_name. An assertion without the placeholder is returned unchanged.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion

    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
542

543

544
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Evaluate the cached preconditions or postconditions of ``wrapped``.

    Each compiled condition is evaluated with the function's globals, the call's
    arguments (``function_locals``) and, for postconditions, the return value bound
    to the name stored with the condition. A condition that cannot be evaluated is
    reported through the debug log and skipped.

    Args:
        wrapped: The function or bound method whose conditions are checked.
        function_locals: Mapping from parameter name to argument value.
        condition_type: Either "precondition" or "postcondition".
        function_return_val: The function's return value (postconditions only).

    Raises:
        PyTAContractError: If a condition evaluates to a falsy value.
        AssertionError: If evaluating a condition itself raises AssertionError.
    """
    # Check bounded function: conditions are cached on the underlying function object
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: skip conditions that fail to evaluate, but let
            # KeyboardInterrupt/SystemExit propagate.
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f"and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string} {return_val_string}"
                )
587

588

589
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.
    The header is matched case-insensitively, and the first docstring line that has
    one of the two supported forms is used (earlier lines that merely start with the
    token, such as a sentence of prose, are ignored).

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    Returns an empty list when the docstring is missing or has no such section.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # Check if obj is an astroid node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__", None) or ""
    lines = [line.strip() for line in docstring.split("\n")]

    single_header = (parse_token + ":").lower()
    group_header = (parse_token + "s:").lower()

    for index, line in enumerate(lines):
        lowered = line.lower()

        if lowered.startswith(single_header):
            # Form 1: the condition follows the header on the same line.
            return [line[len(parse_token) + 1 :].strip()]

        if lowered.startswith(group_header):
            # Form 2: collect the "- <cond>" items until the first other text.
            assertions = []
            for following in lines[index + 1 :]:
                if following.startswith("-"):
                    assertion = following[1:].strip()
                    if hasattr(obj, "__qualname__"):
                        _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                    assertions.append(assertion)
                elif following != "":
                    break
            return assertions

    return []
631

632

633
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return the repr of the given value, shortened for display when it is long.

    When DEBUG_CONTRACTS is False and the repr is longer than max_length characters,
    the middle of the repr is replaced by "...". Otherwise the full repr is returned.

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text

    # Keep the same number of characters on each side of the ellipsis.
    keep = (max_length - 3) // 2
    return text[:keep] + "..." + text[-keep:]
647

648

649
def _display_annotation(annotation: Any) -> str:
    """Return a readable name for the given type annotation.

    Plain classes are shown by name, ``NoneType`` is shown as ``None``, and generic
    annotations (anything with an ``__origin__``) are shown by their repr.

    >>> _display_annotation(int)
    'int'
    >>> _display_annotation(list[int])
    'list[int]'
    >>> from typing import List
    >>> _display_annotation(List[int])
    'typing.List[int]'
    """
    if annotation is type(None):  # Use 'None' instead of 'NoneType'
        return "None"

    is_generic = hasattr(annotation, "__origin__")
    if not is_generic and hasattr(annotation, "__name__"):
        return annotation.__name__

    # Generic annotations, and anything without a __name__
    return repr(annotation)
668

669

670
def _get_module(obj: Any) -> ModuleType:
    """Return the module where obj was defined (normally obj.__module__).

    NOTE: there is a special case for PyCharm's "Run in Python Console" action,
    where the pydevd runner renames the '__main__' module to 'pydev_umd'. When obj
    reports '__main__' as its module, that renamed module is loaded, and obj is a
    function or class whose name is not found in '__main__', the renamed module is
    returned instead. Setting RENAME_MAIN_TO_PYDEV_UMD to False disables this.
    """
    module_name = obj.__module__
    module = sys.modules[module_name]

    pydev_rename_possible = (
        module_name == "__main__"
        and RENAME_MAIN_TO_PYDEV_UMD
        and _PYDEV_UMD_NAME in sys.modules
    )
    if not pydev_rename_possible:
        return module

    # Only functions and classes have a name that can be looked up in the module;
    # for any other type of object, be conservative and just return the module.
    if not isinstance(obj, (FunctionType, type)):
        return module

    if obj.__name__ in vars(module):
        return module
    return sys.modules[_PYDEV_UMD_NAME]
700

701

702
def _debug(msg: str) -> None:
    """Log msg at DEBUG level through the root logger.

    Nothing is logged unless DEBUG_CONTRACTS is True.
    """
    if DEBUG_CONTRACTS:
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
711

712

713
def _set_invariants(klass: type) -> None:
    """Retrieve and set the representation invariants of this class

    The invariants of ``klass`` and of every non-builtin class in its MRO are
    gathered, base classes first, and stored on ``klass`` as
    ``__representation_invariants__``: a list of (assertion, compiled) pairs.
    A class that already has its own stored list contributes that list as is;
    otherwise its docstring is parsed. Invariants that do not compile are
    reported through the debug log and left out.
    """
    # Update representation invariants from this class' docstring and those of its superclasses.
    rep_invariants: list[tuple[str, CodeType]] = []

    # Iterate over all inherited classes except builtins
    for cls in reversed(klass.__mro__):
        if "__representation_invariants__" in cls.__dict__:
            rep_invariants.extend(cls.__representation_invariants__)
        elif cls.__module__ != "builtins":
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
            # Try compiling assertions
            for assertion in assertions:
                try:
                    compiled = compile(assertion, "<string>", "eval")
                except Exception:
                    # Best effort: skip what cannot be compiled, but let
                    # KeyboardInterrupt/SystemExit propagate.
                    _debug(
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
                    )
                    continue
                rep_invariants.append((assertion, compiled))

    setattr(klass, "__representation_invariants__", rep_invariants)
736

737

738
def validate_invariants(obj: object) -> None:
    """Check the representation invariants of obj's class against obj.

    Raises:
        AssertionError: If one of the representation invariants is violated.
    """
    owner = obj.__class__
    owner_module = _get_module(owner)

    try:
        _check_invariants(obj, owner, owner_module.__dict__)
    except PyTAContractError as violation:
        raise AssertionError(str(violation)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc