• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 11645860418

02 Nov 2024 09:13PM UTC coverage: 91.925% (+0.03%) from 91.895%
11645860418

Pull #1103

github

web-flow
Merge ccfdf06ab into 5d50aa36b
Pull Request #1103: Enable Strict Type Checking for Complex Data Types in check_contracts

39 of 40 new or added lines in 1 file covered. (97.5%)

1 existing line in 1 file now uncovered.

3085 of 3356 relevant lines covered (91.92%)

9.18 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.64
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.
2

3
Representation invariants, preconditions, and postconditions are parsed, compiled, and stored.
4
Below are some notes on how they are stored.
5
    - Representation invariants are stored in a class attribute __representation_invariants__
6
    as a list [(assertion, compiled)].
7
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
8
    [(assertion, compiled)].
9
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
10
    [(assertion, compiled, return_val_var_name)].
11
"""
12

13
from __future__ import annotations
10✔
14

15
import inspect
10✔
16
import logging
10✔
17
import sys
10✔
18
import typing
10✔
19
from types import CodeType, FunctionType, ModuleType
10✔
20
from typing import (
10✔
21
    Any,
22
    Callable,
23
    Collection,
24
    Optional,
25
    TypeVar,
26
    Union,
27
    get_args,
28
    get_origin,
29
    overload,
30
)
31

32
import wrapt
10✔
33
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
10✔
34

35
# Configuration options
#
# The public flags below are read each time a checking function runs (not only at
# import time), so they may be changed after this module has been imported.

ENABLE_CONTRACT_CHECKING = True
"""
Set to True to enable contract checking.
"""

DEBUG_CONTRACTS = False
"""
Set to True to display debugging messages when checking contracts.
"""

RENAME_MAIN_TO_PYDEV_UMD = True
"""
Set to False to disable workaround for PyCharm's "Run File in Python Console" action.
In most cases you should not need to change this!
"""

STRICT_NUMERIC_TYPES = True
"""
Set to False to allow more specific numeric types to be accepted by more general type annotations.
"""

# Module name used for '__main__' when a file is run with PyCharm's
# "Run File in Python Console" action (see _get_module).
_PYDEV_UMD_NAME = "pydev_umd"

# Default truncation length for values shown in error messages (see _display_value).
_DEFAULT_MAX_VALUE_LENGTH = 30
# Placeholder that postconditions use to refer to the function's return value.
FUNCTION_RETURN_VALUE = "$return_value"
63

64

65
class PyTAContractError(Exception):
    """Error raised when a PyTA contract assertion is violated.

    The checking helpers in this module raise this error. The contract-checking
    wrappers and validate_invariants re-raise it as an AssertionError with the same
    message.
    """
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Automatically check contracts for all functions and classes in the given modules.

    By default (when called with no arguments), the current module is used.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked.

    Names that are not present in sys.modules are skipped. Each module object is
    processed at most once, even when it is reachable under several of the given names
    (for example ``check_all_contracts(__name__)`` called from the main script).
    Otherwise its functions would be wrapped, and their contracts checked, twice.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    if decorate_main:
        mod_names = mod_names + ("__main__",)

        # Also add _PYDEV_UMD_NAME, handling when the file is being run in PyCharm
        # with the "Run in Python Console" action.
        if RENAME_MAIN_TO_PYDEV_UMD:
            mod_names = mod_names + (_PYDEV_UMD_NAME,)

    # Built once: check_contracts only uses it for membership tests.
    module_name_set = set(mod_names)

    processed_module_ids = set()
    for module_name in mod_names:
        module = sys.modules.get(module_name)
        if module is None or id(module) in processed_module_ids:
            # Module name was passed in incorrectly, or this module was already handled.
            continue
        processed_module_ids.add(id(module))

        for name, value in inspect.getmembers(module):
            if inspect.isfunction(value) or inspect.isclass(value):
                module.__dict__[name] = check_contracts(value, module_names=module_name_set)
102

103

104
@wrapt.decorator
def _enable_function_contracts(wrapped, instance, args, kwargs):
    """Check the contracts of ``wrapped`` around a single call.

    When ``wrapped`` is a class method, ``instance`` is the class itself. No object
    instance exists in that case, so None is forwarded instead. Contract violations
    reach the caller as AssertionError.
    """
    bound_instance = None if inspect.isclass(instance) else instance
    try:
        return _check_function_contracts(wrapped, bound_instance, args, kwargs)
    except PyTAContractError as exc:
        raise AssertionError(str(exc)) from None
115

116

117
# Wildcard Type Variable
Class = TypeVar("Class", bound=type)


@overload
def check_contracts(
    func: FunctionType, module_names: Optional[set[str]] = None
) -> FunctionType: ...


@overload
def check_contracts(func: Class, module_names: Optional[set[str]] = None) -> Class: ...


def check_contracts(
    func_or_class: Union[Class, FunctionType], module_names: Optional[set[str]] = None
) -> Union[Class, FunctionType]:
    """Enable contract checking on a function or class; usable as a decorator.

    A function (or other routine) is returned wrapped, so that its parameter and return
    type annotations, preconditions and postconditions are checked on every call.
    A class is modified in place so that its methods and representation invariants are
    checked, and the same class object is returned.

    If module_names is not None, an object whose ``__module__`` is not in module_names
    is returned unchanged. If ENABLE_CONTRACT_CHECKING is False, the argument is also
    returned unchanged.

    Example:

        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """
    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class

    if inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)

    if inspect.isclass(func_or_class):
        add_class_invariants(func_or_class)

    # Classes are decorated in place; anything else is passed through untouched.
    return func_or_class
167

168

169
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    Does nothing when ENABLE_CONTRACT_CHECKING is False, or when klass already has
    ``__representation_invariants__`` in its own ``__dict__`` (i.e. it was already
    decorated). Otherwise:

      - the representation invariants are collected and stored by _set_invariants;
      - static and class methods defined on klass are wrapped with check_contracts,
        and all other routines with _instance_method_wrapper;
      - ``klass.__setattr__`` is replaced by ``new_setattr`` (defined below).
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # This means the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # This is a cached value set the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        If ``name`` has a class type annotation, ``value`` is type-checked first, and
        AssertionError is raised on a mismatch. When the assignment is not made from a
        function whose ``self`` local is this instance, the representation invariants
        are re-checked. If one is violated, the previous value is restored (or the
        attribute is deleted if it did not exist before) and AssertionError is raised.
        """
        nonlocal cls_annotations

        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

        if name in cls_annotations:
            try:
                # Name the attribute being assigned (``name``), not an outer-scope variable.
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} did not match type annotation for attribute "
                    f"{name}: {_display_annotation(cls_annotations[name])}"
                ) from None

        # Remember the previous value so the assignment can be rolled back if an
        # invariant check fails below.
        original_attr_value_exists = False
        original_attr_value = None
        if hasattr(super(klass, self), name):
            original_attr_value_exists = True
            original_attr_value = super(klass, self).__getattribute__(name)
        super(klass, self).__setattr__(name, value)
        frame_locals = inspect.currentframe().f_back.f_locals
        if self is not frame_locals.get("self"):
            # Only validating if the attribute is not being set in a instance/class method
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None

    # Iterate over a snapshot because the loop rebinds attributes of klass.
    for attr_name, member in list(klass.__dict__.items()):
        if inspect.isroutine(member):
            if isinstance(member, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr_name, check_contracts(member))
            else:
                setattr(klass, attr_name, _instance_method_wrapper(member, klass))

    klass.__setattr__ = new_setattr
234

235

236
def _check_function_contracts(wrapped, instance, args, kwargs):
    """Call ``wrapped(*args, **kwargs)``, checking its contracts around the call.

    The checks run in this order: parameter type annotations, preconditions, the return
    type annotation, and postconditions. Preconditions and postconditions are parsed from
    the docstring and compiled on first use. They are then cached on the underlying
    function as ``__preconditions__`` / ``__postconditions__``.

    Arguments are bound to parameter names positionally. Keyword arguments that name a
    parameter are bound as well, so they are type-checked and visible to conditions.

    Args:
        wrapped: the function (or bound method) being called.
        instance: the object a method is called on, or None.
        args: positional arguments for the call (excluding ``instance``).
        kwargs: keyword arguments for the call.

    Returns:
        The value returned by ``wrapped``.

    Raises:
        PyTAContractError: if an argument or the return value does not match its type
            annotation, or a precondition/postcondition evaluates to a falsy value.
    """
    code = wrapped.__code__
    # Parameters that can be passed positionally (for methods this includes ``self``,
    # which is why ``instance`` is prepended to ``args`` below).
    params = code.co_varnames[: code.co_argcount]
    # Parameters that can be passed by keyword: positional-or-keyword and keyword-only.
    # Positional-only parameters are excluded.
    keyword_params = code.co_varnames[
        code.co_posonlyargcount : code.co_argcount + code.co_kwonlyargcount
    ]
    if instance is not None:
        klass_mod = _get_module(type(instance))
        annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
    else:
        annotations = typing.get_type_hints(wrapped)
    args_with_self = args if instance is None else (instance,) + args

    # Map parameter names to the values passed in this call. A keyword that duplicates
    # a positional argument is ignored here and left for the call itself to reject.
    function_locals = dict(zip(params, args_with_self))
    for param, arg in kwargs.items():
        if param in keyword_params and param not in function_locals:
            function_locals[param] = arg

    # Check function parameter types
    for param, arg in function_locals.items():
        if param not in annotations:
            continue
        try:
            _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict(param, arg, annotations[param])
            else:
                check_type(arg, annotations[param])
        except (TypeError, TypeCheckError):
            additional_suggestions = _get_argument_suggestions(arg, annotations[param])

            raise PyTAContractError(
                f"{wrapped.__name__} argument {_display_value(arg)} did not match type "
                f"annotation for parameter {param}: {_display_annotation(annotations[param])}"
                + (f"\n{additional_suggestions}" if additional_suggestions else "")
            )

    # Bound methods do not accept new attributes, so the compiled conditions are cached
    # on the underlying function instead.
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped

    # Check function preconditions
    if not hasattr(target, "__preconditions__"):
        target.__preconditions__: list[tuple[str, CodeType]] = []
        preconditions = parse_assertions(wrapped)
        for precondition in preconditions:
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                # compile() rejects invalid source with e.g. SyntaxError or ValueError.
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(wrapped, function_locals)

    # Call the function, then check its return type
    r = wrapped(*args, **kwargs)
    if "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"{wrapped.__name__}'s return value {_display_value(r)} did not match "
                f"return type annotation {_display_annotation(return_type)}"
            )

    # Check function postconditions
    if not hasattr(target, "__postconditions__"):
        target.__postconditions__: list[tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        postconditions = parse_assertions(wrapped, parse_token="Postcondition")
        for postcondition in postconditions:
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r
331

332

333
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Check that ``value`` matches ``expected_type``, treating numeric types strictly.

    A value whose type is a narrower member of the numeric hierarchy (bool, int, float,
    complex) is rejected where a wider one is expected. For example, True is rejected
    for ``int``, and 1 is rejected for ``float``. The same rule is applied to values
    nested inside the containers and unions that _check_inner_type understands.

    Does nothing when ENABLE_CONTRACT_CHECKING is False.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``.
    """
    if ENABLE_CONTRACT_CHECKING:
        try:
            _check_inner_type(argname, value, expected_type)
        except (TypeError, TypeCheckError):
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
346

347

348
def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
    """Recursively check that ``value`` matches ``expected_type`` with strict numeric types.

    Supported forms:
      - plain types;
      - unions, written either as ``typing.Union``/``Optional`` or with ``X | Y``;
      - the parameterized containers list, set, dict and tuple (including
        ``tuple[T, ...]``), whose items are checked recursively.

    Anything else, including unparameterized aliases such as ``typing.List``, is
    delegated to typeguard's ``check_type``.

    Raises:
        TypeError or TypeCheckError: if ``value`` does not match ``expected_type``.
    """
    import types  # local: only selected names from ``types`` are imported at module level

    inner_types = get_args(expected_type)
    outer_type = get_origin(expected_type)
    # ``X | Y`` annotations (Python 3.10+) have origin types.UnionType, not typing.Union;
    # both spellings must get the same strict treatment.
    union_type = getattr(types, "UnionType", None)

    if outer_type is None:
        if (
            (type(value) is bool and expected_type in {int, float, complex})
            or (type(value) is int and expected_type in {float, complex})
            or (type(value) is float and expected_type is complex)
        ):
            raise TypeError(
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
            )
        else:
            check_type(
                value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
    elif outer_type is typing.Union or (union_type is not None and outer_type is union_type):
        for inner_type in inner_types:
            try:
                _check_inner_type(argname, value, inner_type)
                return
            except (TypeError, TypeCheckError):
                pass
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type in {list, set, dict} and not inner_types:
        # Unparameterized alias (e.g. typing.List): there are no item types to recurse
        # into, and indexing inner_types below would raise IndexError.
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
    elif outer_type in {list, set}:
        if isinstance(value, outer_type):
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is dict:
        if isinstance(value, dict):
            for key, item in value.items():
                _check_inner_type(argname, key, inner_types[0])
                _check_inner_type(argname, item, inner_types[1])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    elif outer_type is tuple:
        if isinstance(value, tuple) and len(inner_types) == 2 and inner_types[1] is Ellipsis:
            # Homogeneous variable-length tuple: tuple[T, ...]
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        elif isinstance(value, tuple) and len(value) == len(inner_types):
            # Fixed-length tuple: one annotation per position
            for i, item in enumerate(value):
                _check_inner_type(argname, item, inner_types[i])
        else:
            raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    else:
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
402

403

404
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
    """Return a hint for a mismatched argument, or "" when there is nothing to suggest.

    The hint is produced when ``arg`` is itself a class that is a subclass of
    ``annotation``. That usually means the caller passed a class where an instance was
    wanted.
    """
    suggestion = ""
    try:
        if isinstance(arg, type) and issubclass(arg, annotation):
            cls_name = arg.__name__
            suggestion = f"Did you mean {cls_name}(...) instead of {cls_name}?"
    except TypeError:
        # issubclass rejects annotations that are not classes (e.g. list[int]).
        suggestion = ""
    return suggestion
413

414

415
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Return ``wrapped`` (an instance method of klass) wrapped with contract checking.

    Every call checks the method's own contracts via _check_function_contracts.
    Unless the call happens while the instance's ``__init__`` is still on the call
    stack, the class type annotations are then checked. The representation invariants
    are also checked, provided the class's module is known and checking is enabled.
    PyTAContractError is re-raised as AssertionError.
    """

    @wrapt.decorator
    def checked_call(wrapped, instance, args, kwargs):
        try:
            result = _check_function_contracts(wrapped, instance, args, kwargs)
            # Class-level checks are skipped while __init__ for this instance is running.
            if not _instance_init_in_callstack(instance):
                _check_class_type_annotations(klass, instance)
                module = _get_module(klass)
                if ENABLE_CONTRACT_CHECKING and module is not None:
                    _check_invariants(instance, klass, module.__dict__)
        except PyTAContractError as err:
            raise AssertionError(str(err)) from None
        return result

    return checked_call(wrapped)
432

433

434
def _instance_init_in_callstack(instance: Any) -> bool:
    """Return whether instance's init is part of the current callstack

    Walks the stack starting at the caller of this function. It looks for a frame that
    meets all three conditions:
      - the frame's function is named ``__init__``;
      - its first parameter is ``self``;
      - its ``self`` local is ``instance``.

    Note: due to the nature of the check, externally defined __init__ functions with
    'self' defined as the first parameter may pass this check.
    """
    current = inspect.currentframe()
    # inspect.currentframe() can return None on Python implementations without stack
    # frame support; treat that as "not inside __init__".
    frame = current.f_back if current is not None else None
    try:
        while frame is not None:
            code = frame.f_code
            # code.co_name is the value inspect.getframeinfo(frame).function reports,
            # without getframeinfo's cost of reading source lines for every frame.
            if (
                code.co_name == "__init__"
                and frame.f_locals.get("self") is instance
                and code.co_varnames[:1] == ("self",)
            ):
                return True
            frame = frame.f_back
        return False
    finally:
        # Drop frame references promptly to avoid reference cycles.
        del current, frame
453

454

455
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Every attribute annotated on klass is read from instance and checked with
    typeguard. The annotations are resolved with typing.get_type_hints, which includes
    annotations inherited from base classes.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an attribute value does not match its annotation.
    """
    klass_mod = _get_module(klass)
    cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

    for attr, annotation in cls_annotations.items():
        # NOTE(review): an annotated attribute that was never assigned raises
        # AttributeError here rather than a contract error.
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress the typeguard context, like the other AssertionErrors raised for
            # contract violations in this module.
            raise AssertionError(
                f"{_display_value(value)} did not match type annotation for attribute {attr}: "
                f"{_display_annotation(annotation)}"
            ) from None
476

477

478
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    Each compiled invariant in ``klass.__representation_invariants__`` is evaluated with
    ``global_scope`` plus ``self`` bound to ``instance``. While the check runs, the
    marker attribute ``__pyta_currently_checking`` is set on the instance, and
    recursive calls for the same instance return immediately. The marker is removed
    afterwards.

    Raises:
        AssertionError: if evaluating an invariant raises AssertionError.
        PyTAContractError: if an invariant evaluates to a falsy value.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except Exception:
                # An invariant that cannot be evaluated is skipped (best effort).
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'"{instance.__class__.__name__}" representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        delattr(instance, "__pyta_currently_checking")
517

518

519
def _get_legal_return_val_var_name(var_dict: dict) -> str:
    """Return a name for the function's return value that is not a key of var_dict.

    The name is "__function_return_value__" followed by the fewest extra underscores
    needed to avoid every name in var_dict. Postconditions use this name to refer to
    the function's return value when they are evaluated.
    """
    base_name = "__function_return_value__"
    extra_underscores = 0
    while base_name + "_" * extra_underscores in var_dict:
        extra_underscores += 1
    return base_name + "_" * extra_underscores
531

532

533
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """Substitute the generated return-value variable name into a postcondition.

    Every occurrence of FUNCTION_RETURN_VALUE in assertion is replaced by
    return_val_var_name. An assertion that does not mention FUNCTION_RETURN_VALUE is
    returned as is, and return_val_var_name is not used.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion
    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
544

545

546
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the given assertions are still satisfied.

    Which cached conditions are evaluated depends on condition_type:
      - "precondition": the compiled preconditions cached on ``wrapped``;
      - "postcondition": the compiled postconditions cached on ``wrapped``.

    Each condition is evaluated in a namespace built from the function's globals. That
    namespace is overridden by ``function_locals`` and, for postconditions, by the
    generated return-value variable bound to ``function_return_val``.

    A condition that raises AssertionError propagates as AssertionError. A condition
    that raises any other exception is skipped with a debug warning.

    Raises:
        PyTAContractError: if a condition evaluates to a falsy value.
    """
    # Check bounded function
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: a condition that cannot be evaluated is skipped, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f"and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string} {return_val_string}"
                )
589

590

591
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    The docstring is taken from ``obj.doc_node`` when obj is an astroid node, and from
    ``obj.__doc__`` otherwise. parse_token selects what to look for; it defaults to
    Precondition.

    Only the first stripped line that starts with parse_token is considered. That line
    is located case-insensitively, but it must then match one of the two supported
    forms exactly (case-sensitive); otherwise no assertions are returned:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.
    """
    doc_node = getattr(obj, "doc_node", None)
    if doc_node is not None:
        # obj is an astroid node
        docstring = doc_node.value
    else:
        docstring = obj.__doc__ or ""

    lines = [raw.strip() for raw in docstring.split("\n")]
    token_lower = parse_token.lower()
    first = next(
        (index for index, text in enumerate(lines) if text.lower().startswith(token_lower)),
        None,
    )
    if first is None:
        return []

    header = lines[first]
    single_prefix = parse_token + ":"
    if header.startswith(single_prefix):
        return [header[len(single_prefix) :].strip()]
    if not header.startswith(parse_token + "s:"):
        return []

    assertions = []
    for text in lines[first + 1 :]:
        if not text:
            # Blank lines may separate list items.
            continue
        if not text.startswith("-"):
            # Any other text ends the list.
            break
        assertion = text[1:].strip()
        if hasattr(obj, "__qualname__"):
            _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
        assertions.append(assertion)
    return assertions
633

634

635
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return ``repr(value)``, shortened for display in error messages.

    When DEBUG_CONTRACTS is False and the repr is longer than max_length, the middle
    is replaced by "...". An equal number of characters is kept from each end, so the
    result is at most max_length characters long.

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text
    keep = (max_length - 3) // 2
    return f"{text[:keep]}...{text[-keep:]}"
649

650

651
def _display_annotation(annotation: Any) -> str:
    """Return a human-friendly representation of the given type annotation.

    NoneType is shown as "None". Generic aliases (objects with ``__origin__``) use
    their repr, other named objects use ``__name__``, and everything else falls back
    to repr.

    >>> _display_annotation(int)
    'int'
    >>> _display_annotation(list[int])
    'list[int]'
    >>> from typing import List
    >>> _display_annotation(List[int])
    'typing.List[int]'
    """
    if annotation is type(None):
        return "None"
    is_generic = hasattr(annotation, "__origin__")
    if not is_generic and hasattr(annotation, "__name__"):
        return annotation.__name__
    return repr(annotation)
670

671

672
def _get_module(obj: Any) -> ModuleType:
    """Return the module where obj was defined (normally obj.__module__).

    NOTE: this function defines a special case when using PyCharm and the file
    defining the object is "Run in Python Console". In this case, the pydevd runner
    renames the '__main__' module to 'pydev_umd', and so we need to access that
    module instead. This behaviour can be disabled by setting RENAME_MAIN_TO_PYDEV_UMD
    to False.

    When the special case applies, functions and classes are looked up by name in
    '__main__'. If they are not found there, the 'pydev_umd' module is returned.
    """
    module_name = obj.__module__
    module = sys.modules[module_name]

    pydev_rename_applies = (
        module_name == "__main__"
        and RENAME_MAIN_TO_PYDEV_UMD
        and _PYDEV_UMD_NAME in sys.modules
    )
    if not pydev_rename_applies:
        return module

    if not isinstance(obj, (FunctionType, type)):
        # For any other type of object, be conservative and just return the module
        return module

    return module if obj.__name__ in vars(module) else sys.modules[_PYDEV_UMD_NAME]
702

703

704
def _debug(msg: str) -> None:
    """Log msg at DEBUG level when DEBUG_CONTRACTS is True; otherwise do nothing.

    logging.basicConfig is called for each message. Per the logging documentation, it
    has no effect if the root logger already has handlers configured.
    """
    if DEBUG_CONTRACTS:
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
713

714

715
def _set_invariants(klass: type) -> None:
    """Retrieve and set the representation invariants of this class

    Invariants are collected from klass and its superclasses, walking the MRO from the
    root down to klass:
      - a class with ``__representation_invariants__`` in its own ``__dict__``
        contributes that list as is;
      - any other class outside ``builtins`` has the "Representation Invariant" entries
        of its docstring parsed and compiled.

    Entries that fail to compile are skipped with a debug warning. The result is stored
    on klass as ``__representation_invariants__``, a list of (source, code object) pairs.
    """
    # Update representation invariants from this class' docstring and those of its superclasses.
    rep_invariants: list[tuple[str, CodeType]] = []

    # Iterate over all inherited classes except builtins
    for cls in reversed(klass.__mro__):
        if "__representation_invariants__" in cls.__dict__:
            rep_invariants.extend(cls.__representation_invariants__)
        elif cls.__module__ != "builtins":
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
            # Try compiling assertions
            for assertion in assertions:
                try:
                    compiled = compile(assertion, "<string>", "eval")
                except Exception:
                    # compile() rejects invalid source with e.g. SyntaxError or ValueError;
                    # KeyboardInterrupt/SystemExit are no longer swallowed.
                    _debug(
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
                    )
                    continue
                rep_invariants.append((assertion, compiled))

    setattr(klass, "__representation_invariants__", rep_invariants)
738

739

740
def validate_invariants(obj: object) -> None:
    """Check that the representation invariants of obj are satisfied.

    The invariants of ``obj``'s class are evaluated against the globals of the module
    that defines the class. A violated invariant is reported as AssertionError.
    """
    obj_class = obj.__class__
    class_globals = _get_module(obj_class).__dict__
    try:
        _check_invariants(obj, obj_class, class_globals)
    except PyTAContractError as err:
        raise AssertionError(str(err)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc