• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyta-uoft / pyta / 11606083184

31 Oct 2024 04:21AM UTC coverage: 91.958% (+0.06%) from 91.895%
11606083184

Pull #1103

github

web-flow
Merge b4565b3f4 into a307a7e8a
Pull Request #1103: Enable Strict Type Checking for Complex Data Types in check_contracts

36 of 36 new or added lines in 1 file covered. (100.0%)

20 existing lines in 1 file now uncovered.

3076 of 3345 relevant lines covered (91.96%)

9.18 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.91
/python_ta/contracts/__init__.py
1
"""This module provides the functionality for PythonTA contracts.
2

3
Representation invariants, preconditions, and postconditions are parsed, compiled, and stored.
4
Below are some notes on how they are stored.
5
    - Representation invariants are stored in a class attribute __representation_invariants__
6
    as a list [(assertion, compiled)].
7
    - Preconditions are stored in an attribute __preconditions__ of the function as a list
8
    [(assertion, compiled)].
9
    - Postconditions are stored in an attribute __postconditions__ of the function as a list
10
    [(assertion, compiled, return_val_var_name)].
11
"""
12

13
from __future__ import annotations
10✔
14

15
import inspect
10✔
16
import logging
10✔
17
import sys
10✔
18
import typing
10✔
19
from types import CodeType, FunctionType, ModuleType
10✔
20
from typing import (
10✔
21
    Any,
22
    Callable,
23
    Collection,
24
    Optional,
25
    TypeVar,
26
    Union,
27
    get_args,
28
    get_origin,
29
    overload,
30
)
31

32
import wrapt
10✔
33
from typeguard import CollectionCheckStrategy, TypeCheckError, check_type
10✔
34

35
# ---------------------------------------------------------------------------
# Configuration options. These module-level flags are read when contracts are
# applied and checked, so they can be reassigned by users of this module.
# ---------------------------------------------------------------------------

ENABLE_CONTRACT_CHECKING = True
"""Whether contract checking is enabled. Set to False to turn contract checking off."""

DEBUG_CONTRACTS = False
"""Whether to log debugging messages while contracts are being checked."""

RENAME_MAIN_TO_PYDEV_UMD = True
"""Whether to apply the workaround for PyCharm's "Run File in Python Console" action.

That action renames the '__main__' module to 'pydev_umd'. In most cases you should not need
to change this setting!
"""

STRICT_NUMERIC_TYPES = True
"""Whether numeric type annotations are checked strictly.

When True, a more specific numeric type is not accepted for a more general annotation
(for example, a bool does not match ``int`` and an int does not match ``float``).
Set to False to accept such values.
"""

# Internal constants
_PYDEV_UMD_NAME = "pydev_umd"  # module name PyCharm's Python Console gives the run file
_DEFAULT_MAX_VALUE_LENGTH = 30  # default truncation length for values shown in messages
FUNCTION_RETURN_VALUE = "$return_value"  # placeholder for the return value in postconditions
63

64

65
class PyTAContractError(Exception):
    """Raised when a PyTA contract is violated.

    A contract is a type annotation, a precondition, a postcondition, or a representation
    invariant. The checking helpers raise this error, and the decorators in this module catch it
    and re-raise it as an AssertionError carrying the same message.
    """
67

68

69
def check_all_contracts(*mod_names: str, decorate_main: bool = True) -> None:
    """Enable contract checking for every function and class in the given modules.

    By default (when called with no arguments), only the module being run is used.

    Args:
        *mod_names: The names of modules to check contracts for. These modules must have been
            previously imported; names that are not in sys.modules are ignored.
        decorate_main: True if the module being run (where __name__ == '__main__') should
            have contracts checked.
    """
    if not ENABLE_CONTRACT_CHECKING:
        return

    names = list(mod_names)
    if decorate_main:
        names.append("__main__")
        # PyCharm's "Run in Python Console" action runs the file under _PYDEV_UMD_NAME.
        if RENAME_MAIN_TO_PYDEV_UMD:
            names.append(_PYDEV_UMD_NAME)

    # Only members defined in one of these modules get decorated (see check_contracts).
    allowed_modules = set(names)
    modules = [sys.modules.get(module_name, None) for module_name in names]

    for module in modules:
        if not module:
            # The module name was passed in incorrectly (or never imported).
            continue
        for member_name, member in inspect.getmembers(module):
            if inspect.isfunction(member) or inspect.isclass(member):
                module.__dict__[member_name] = check_contracts(
                    member, module_names=allowed_modules
                )
102

103

104
@wrapt.decorator
def _enable_function_contracts(wrapped, instance, args, kwargs):
    """A wrapt decorator that checks the contracts of ``wrapped`` every time it is called.

    Contract violations (PyTAContractError) are re-raised as AssertionError.
    """
    # For a classmethod, wrapt passes the class as ``instance``; such calls are checked as if
    # there were no instance at all.
    receiver = None if inspect.isclass(instance) else instance
    try:
        return _check_function_contracts(wrapped, receiver, args, kwargs)
    except PyTAContractError as error:
        raise AssertionError(str(error)) from None
115

116

117
# Type variable standing for any class object passed to check_contracts.
Class = TypeVar("Class", bound=type)


@overload
def check_contracts(
    func: FunctionType, module_names: Optional[set[str]] = None
) -> FunctionType: ...


@overload
def check_contracts(func: Class, module_names: Optional[set[str]] = None) -> Class: ...


def check_contracts(
    func_or_class: Union[Class, FunctionType], module_names: Optional[set[str]] = None
) -> Union[Class, FunctionType]:
    """A decorator to enable contract checking for a function or class.

    For a function (or other routine), return a wrapper that checks its contracts on every call.
    For a class, modify the class in place so that all methods defined within it have contract
    checking enabled, and return the class itself.

    If module_names is not None, only functions or classes defined in a module whose name is in
    module_names are checked; anything else is returned unchanged. Everything is also returned
    unchanged when ENABLE_CONTRACT_CHECKING is False.

    Example:

        >>> from python_ta.contracts import check_contracts
        >>> @check_contracts
        ... def divide(x: int, y: int) -> int:
        ...     \"\"\"Return x // y.
        ...
        ...     Preconditions:
        ...        - y != 0
        ...     \"\"\"
        ...     return x // y
    """
    if not ENABLE_CONTRACT_CHECKING:
        return func_or_class

    if module_names is not None and func_or_class.__module__ not in module_names:
        _debug(
            f"Warning: skipping contract check for {func_or_class.__name__} defined in {func_or_class.__module__} because module is not included as an argument."
        )
        return func_or_class

    if inspect.isroutine(func_or_class):
        return _enable_function_contracts(func_or_class)

    if inspect.isclass(func_or_class):
        add_class_invariants(func_or_class)

    # Classes are modified in place; any other object is returned as-is.
    return func_or_class
167

168

169
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    This:
        - collects the class's representation invariants (see _set_invariants);
        - wraps every routine defined directly in the class. Staticmethods and classmethods
          only have their own contracts checked. Instance methods additionally have the
          attribute type annotations and representation invariants checked after each call;
        - replaces the class's ``__setattr__``. Assigning an annotated attribute checks the
          value against its annotation. Assignments made from a frame whose ``self`` is not the
          instance also check the representation invariants, and the assignment is undone if
          an invariant is violated.

    Does nothing if contract checking is disabled or the class has already been decorated.
    """
    if not ENABLE_CONTRACT_CHECKING or "__representation_invariants__" in klass.__dict__:
        # Contract checking is off, or the class has already been decorated
        return

    _set_invariants(klass)

    klass_mod = _get_module(klass)
    cls_annotations = None  # Computed and cached the first time new_setattr is called

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        If the attribute has a type annotation, the value is checked against it first.
        Representation invariants are checked unless the assignment comes from a frame whose
        ``self`` is this instance (i.e. from within one of its methods).
        """
        nonlocal cls_annotations

        if not ENABLE_CONTRACT_CHECKING:
            super(klass, self).__setattr__(name, value)
            return

        if cls_annotations is None:
            cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

        if name in cls_annotations:
            try:
                _debug(f"Checking type of attribute {name} for {klass.__qualname__} instance")
                check_type(
                    value,
                    cls_annotations[name],
                    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
                )
            except TypeCheckError:
                raise AssertionError(
                    f"Value {_display_value(value)} did not match type annotation for attribute "
                    f"{name}: {_display_annotation(cls_annotations[name])}"
                ) from None

        # Remember the current value so the assignment can be undone if it breaks an invariant.
        # Note: hasattr(super(klass, self), name) only searches the base classes and never sees
        # instance attributes, so look the name up through __getattribute__ instead.
        try:
            original_attr_value = super(klass, self).__getattribute__(name)
        except AttributeError:
            original_attr_value_exists = False
            original_attr_value = None
        else:
            original_attr_value_exists = True

        super(klass, self).__setattr__(name, value)
        frame_locals = inspect.currentframe().f_back.f_locals
        if self is not frame_locals.get("self"):
            # Only validate if the attribute is not being set in an instance/class method
            if klass_mod is not None:
                try:
                    _check_invariants(self, klass, klass_mod.__dict__)
                except PyTAContractError as e:
                    if original_attr_value_exists:
                        super(klass, self).__setattr__(name, original_attr_value)
                    else:
                        super(klass, self).__delattr__(name)
                    raise AssertionError(str(e)) from None

    # Iterate over a snapshot, since the loop reassigns attributes of klass.
    for member_name, member in list(klass.__dict__.items()):
        if inspect.isroutine(member):
            if isinstance(member, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, member_name, check_contracts(member))
            else:
                setattr(klass, member_name, _instance_method_wrapper(member, klass))

    klass.__setattr__ = new_setattr
234

235

236
def _bind_arguments(
    func: Callable[..., Any], args_with_self: tuple, kwargs: dict
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Match the arguments of a call to the parameters of ``func``.

    ``func`` is the underlying (unbound) function, so for methods ``args_with_self`` must start
    with the receiver (``self`` or ``cls``).

    Returns a pair ``(passed, function_locals)`` of dicts mapping parameter names to values:
        - ``passed`` contains the non-variadic parameters the caller supplied explicitly,
          positionally or by keyword. These are the values whose types are checked.
        - ``function_locals`` contains every parameter bound by the call, plus the parameters
          that take their default value. This is the namespace used for pre/postconditions.

    If the call does not match the signature, or no signature is available, only positional
    arguments are paired with positional parameters. Calling the function then raises its
    usual error.
    """
    try:
        signature = inspect.signature(func, follow_wrapped=False)
        bound = signature.bind(*args_with_self, **kwargs)
    except (TypeError, ValueError, AttributeError):
        code = func.__code__
        positional = dict(zip(code.co_varnames[: code.co_argcount], args_with_self))
        return positional, dict(positional)

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    passed: dict[str, Any] = {}
    function_locals: dict[str, Any] = {}
    for name, parameter in signature.parameters.items():
        if name in bound.arguments:
            function_locals[name] = bound.arguments[name]
            if parameter.kind not in variadic_kinds:
                passed[name] = bound.arguments[name]
        elif parameter.default is not inspect.Parameter.empty:
            # Not supplied by the caller: make the default visible to pre/postconditions
            function_locals[name] = parameter.default
    return passed, function_locals


def _check_function_contracts(wrapped, instance, args, kwargs):
    """Call ``wrapped(*args, **kwargs)``, checking its contracts before and after the call.

    The checks run in this order:
        1. the types of the arguments the caller supplied, against the parameter annotations;
        2. the preconditions;
        3. (the function is called)
        4. the type of the return value, against the return annotation;
        5. the postconditions.

    Pre- and postconditions are parsed from the docstring and compiled on the first call. They
    are cached on the underlying function as ``__preconditions__`` and ``__postconditions__``.

    Args:
        wrapped: the function or bound method being called.
        instance: the object an instance method is bound to, or None.
        args: the positional arguments of the call, not including ``instance``.
        kwargs: the keyword arguments of the call.

    Returns:
        The value returned by ``wrapped``.

    Raises:
        PyTAContractError: if an argument or the return value does not match its annotation,
            or (via _check_assertions) if a pre- or postcondition is violated.
    """
    if instance is not None:
        klass_mod = _get_module(type(instance))
        annotations = typing.get_type_hints(wrapped, globalns=klass_mod.__dict__)
        args_with_self = (instance,) + args
    else:
        annotations = typing.get_type_hints(wrapped)
        if inspect.ismethod(wrapped):
            # A bound method called without an instance (e.g. a classmethod): ``args`` omits the
            # object it is bound to, so add it back to line up with the function's parameters.
            args_with_self = (wrapped.__self__,) + args
        else:
            args_with_self = args

    # Contracts are cached on (and the signature read from) the function of a bound method
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped

    passed_args, function_locals = _bind_arguments(target, args_with_self, kwargs)

    # Check the types of the arguments supplied by the caller (positionally or by keyword)
    for param, arg in passed_args.items():
        if param not in annotations:
            continue
        try:
            _debug(f"Checking type of parameter {param} in call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict(param, arg, annotations[param])
            else:
                check_type(arg, annotations[param])
        except (TypeError, TypeCheckError):
            additional_suggestions = _get_argument_suggestions(arg, annotations[param])

            raise PyTAContractError(
                f"{wrapped.__name__} argument {_display_value(arg)} did not match type "
                f"annotation for parameter {param}: {_display_annotation(annotations[param])}"
                + (f"\n{additional_suggestions}" if additional_suggestions else "")
            )

    # Parse and compile the preconditions on the first call
    if not hasattr(target, "__preconditions__"):
        target.__preconditions__: list[tuple[str, CodeType]] = []
        preconditions = parse_assertions(wrapped)
        for precondition in preconditions:
            try:
                compiled = compile(precondition, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: precondition {precondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__preconditions__.append((precondition, compiled))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(wrapped, function_locals)

    # Call the function and check its return type
    r = wrapped(*args, **kwargs)
    if "return" in annotations:
        return_type = annotations["return"]
        try:
            _debug(f"Checking return type from call to {wrapped.__qualname__}")
            if STRICT_NUMERIC_TYPES:
                check_type_strict("return", r, return_type)
            else:
                check_type(r, return_type)
        except (TypeError, TypeCheckError):
            raise PyTAContractError(
                f"{wrapped.__name__}'s return value {_display_value(r)} did not match "
                f"return type annotation {_display_annotation(return_type)}"
            )

    # Parse and compile the postconditions on the first call
    if not hasattr(target, "__postconditions__"):
        target.__postconditions__: list[tuple[str, CodeType, str]] = []
        return_val_var_name = _get_legal_return_val_var_name(
            {**wrapped.__globals__, **function_locals}
        )
        postconditions = parse_assertions(wrapped, parse_token="Postcondition")
        for postcondition in postconditions:
            assertion = _replace_return_val_assertion(postcondition, return_val_var_name)
            try:
                compiled = compile(assertion, "<string>", "eval")
            except Exception:
                _debug(
                    f"Warning: postcondition {postcondition} could not be parsed as a valid Python expression"
                )
                continue
            target.__postconditions__.append((postcondition, compiled, return_val_var_name))

    if ENABLE_CONTRACT_CHECKING:
        _check_assertions(
            wrapped,
            function_locals,
            function_return_val=r,
            condition_type="postcondition",
        )

    return r
331

332

333
def check_type_strict(argname: str, value: Any, expected_type: type) -> None:
    """Ensure that ``value`` matches ``expected_type``, treating numeric types strictly.

    Differentiates between:
        - float vs. int
        - bool vs. int
    including inside unions and collections (see _check_inner_type).

    Does nothing when ENABLE_CONTRACT_CHECKING is False.

    Raises:
        TypeError: if ``value`` does not match ``expected_type``.
    """
    if not ENABLE_CONTRACT_CHECKING:
        # Previously ``pass``, which fell through and checked the type anyway.
        return
    try:
        _check_inner_type(argname, value, expected_type)
    except (TypeError, TypeCheckError):
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
346

347

348
def _check_inner_type(argname: str, value: Any, expected_type: type) -> None:
    """Check that ``value`` matches ``expected_type``, treating numeric types strictly.

    The strict rules:
        - a bool does not match int, float or complex;
        - an int does not match float or complex;
        - a float does not match complex.
    The rules apply recursively to the members of unions (``typing.Union``, and ``X | Y`` on
    Python 3.10+). They also apply to the elements of lists, sets, frozensets, tuples
    (``tuple[T, ...]`` and ``tuple[T1, T2, ...]``) and the keys/values of dicts.

    Every generic annotation is first checked as a whole with typeguard's ``check_type``. This
    verifies the container itself, e.g. that a value annotated ``list[int]`` really is a list,
    and checks tuple lengths. The element-wise strict checks run after that.

    Raises:
        TypeError or TypeCheckError: if ``value`` does not match ``expected_type``.
    """
    import types  # only needed for types.UnionType (the type of ``X | Y``, Python 3.10+)

    inner_types = get_args(expected_type)
    outer_type = get_origin(expected_type)

    if outer_type is None:
        # A plain (non-generic) type: apply the strict numeric rules, then check normally
        if (
            (type(value) is bool and expected_type in {int, float, complex})
            or (type(value) is int and expected_type in {float, complex})
            or (type(value) is float and expected_type is complex)
        ):
            raise TypeError(
                f"type of {argname} must be {expected_type}; got {type(value).__name__} instead"
            )
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
    elif outer_type is typing.Union or outer_type is getattr(types, "UnionType", None):
        # A union matches if any one of its members matches (strictly)
        for inner_type in inner_types:
            try:
                _check_inner_type(argname, value, inner_type)
                return
            except (TypeError, TypeCheckError):
                pass
        raise TypeError(f"type of {argname} must be {expected_type}; got {value} instead")
    else:
        # Verify the value as a whole (container type, tuple length, element types)
        check_type(
            value, expected_type, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
        )
        if not inner_types or isinstance(value, str) or not isinstance(value, Collection):
            return

        # Re-check the elements with the strict numeric rules
        if outer_type is tuple:
            if len(inner_types) == 2 and inner_types[1] is Ellipsis:
                # tuple[T, ...]: every element has type T
                for item in value:
                    _check_inner_type(argname, item, inner_types[0])
            else:
                # tuple[T1, T2, ...]: one type per position (length already checked above)
                for item, item_type in zip(value, inner_types):
                    _check_inner_type(argname, item, item_type)
        elif outer_type in (list, set, frozenset):
            for item in value:
                _check_inner_type(argname, item, inner_types[0])
        elif outer_type is dict and isinstance(value, dict):
            for key, item in value.items():
                _check_inner_type(argname, key, inner_types[0])
                _check_inner_type(argname, item, inner_types[1])
388

389

390
def _get_argument_suggestions(arg: Any, annotation: type) -> str:
    """Return a hint for fixing an argument that did not match ``annotation``, or "".

    The only hint currently produced is for passing a class where an instance of that class
    was expected. The empty string is returned when no hint applies, including when
    ``annotation`` cannot be used with issubclass (e.g. a parameterized generic).
    """
    if not isinstance(arg, type):
        return ""
    try:
        passed_class_not_instance = issubclass(arg, annotation)
    except TypeError:
        return ""
    if not passed_class_not_instance:
        return ""
    class_name = arg.__name__
    return f"Did you mean {class_name}(...) instead of {class_name}?"
399

400

401
def _instance_method_wrapper(wrapped: Callable, klass: type) -> Callable:
    """Wrap an instance method of ``klass`` so that its contracts are checked on every call.

    The method's own contracts are checked first. Afterwards, unless the instance's
    ``__init__`` is still running somewhere up the call stack, the attribute type annotations
    and (when contract checking is enabled) the representation invariants are checked for the
    instance. Contract violations are raised as AssertionError.
    """

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        try:
            result = _check_function_contracts(wrapped, instance, args, kwargs)
            if not _instance_init_in_callstack(instance):
                _check_class_type_annotations(klass, instance)
                module = _get_module(klass)
                if module is not None and ENABLE_CONTRACT_CHECKING:
                    _check_invariants(instance, klass, module.__dict__)
        except PyTAContractError as error:
            raise AssertionError(str(error)) from None
        return result

    return wrapper(wrapped)
418

419

420
def _instance_init_in_callstack(instance: Any) -> bool:
    """Return whether instance's __init__ is part of the current callstack.

    A frame counts as that ``__init__`` when all of the following hold:
        - its function is named ``__init__``;
        - its code's first local variable (its first parameter, if it has any) is ``self``;
        - its ``self`` refers to ``instance``.

    Note: due to the nature of the check, externally defined __init__ functions with
    'self' defined as the first parameter may pass this check.
    """
    frame = inspect.currentframe().f_back
    while frame is not None:
        code = frame.f_code
        # code.co_name is the same function name that inspect.getframeinfo reports, but
        # getframeinfo also reads the frame's source file, which is costly on every call.
        # Slicing co_varnames avoids an IndexError for functions with no local variables.
        if (
            code.co_name == "__init__"
            and code.co_varnames[:1] == ("self",)
            and frame.f_locals.get("self") is instance
        ):
            return True
        frame = frame.f_back
    return False
439

440

441
def _check_class_type_annotations(klass: type, instance: Any) -> None:
    """Check that the type annotations for the class still hold.

    Annotations are collected from ``klass`` with typing.get_type_hints, resolving names in the
    module where ``klass`` was defined. Each annotated attribute of ``instance`` is checked
    against its annotation.

    Precondition:
        - isinstance(instance, klass)

    Raises:
        AssertionError: if an attribute's value does not match its type annotation.
    """
    klass_mod = _get_module(klass)
    cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

    for attr, annotation in cls_annotations.items():
        value = getattr(instance, attr)
        try:
            _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
            check_type(
                value, annotation, collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS
            )
        except TypeCheckError:
            # Suppress typeguard's internal error so users only see the contract message
            raise AssertionError(
                f"{_display_value(value)} did not match type annotation for attribute {attr}: "
                f"{_display_annotation(annotation)}"
            ) from None
462

463

464
def _check_invariants(instance, klass: type, global_scope: dict) -> None:
    """Check that the representation invariants for the instance are satisfied.

    The compiled invariants stored on ``klass`` as ``__representation_invariants__`` are each
    evaluated with ``global_scope`` as globals and ``self`` bound to ``instance``. An invariant
    whose evaluation raises any exception other than AssertionError is skipped, with a debug
    message.

    While checking, a ``__pyta_currently_checking`` marker attribute is set on ``instance`` so
    that re-entrant calls for the same instance return immediately.

    Raises:
        PyTAContractError: if an invariant evaluates to a falsy value.
        AssertionError: if evaluating an invariant raises AssertionError.
    """
    if hasattr(instance, "__pyta_currently_checking"):
        # If already checking invariants for this instance, skip to avoid infinite recursion
        return

    # Set the marker via the parent class's __setattr__, bypassing the instance's own class
    super(type(instance), instance).__setattr__("__pyta_currently_checking", True)

    rep_invariants = getattr(klass, "__representation_invariants__", set())

    try:
        for invariant, compiled in rep_invariants:
            try:
                _debug(
                    "Checking representation invariant for "
                    f"{instance.__class__.__qualname__}: {invariant}"
                )
                check = eval(compiled, {**global_scope, "self": instance})
            except AssertionError as e:
                raise AssertionError(str(e)) from None
            except Exception:
                # Best effort: an invariant that cannot be evaluated is reported and skipped.
                # (A bare ``except:`` here also swallowed KeyboardInterrupt/SystemExit.)
                _debug(f"Warning: could not evaluate representation invariant: {invariant}")
            else:
                if not check:
                    curr_attributes = ", ".join(
                        f"{k}: {_display_value(v)}"
                        for k, v in vars(instance).items()
                        if k != "__pyta_currently_checking"
                    )

                    curr_attributes = "{" + curr_attributes + "}"

                    raise PyTAContractError(
                        f'"{instance.__class__.__name__}" representation invariant "{invariant}" was violated for'
                        f" instance attributes {curr_attributes}"
                    )

    finally:
        delattr(instance, "__pyta_currently_checking")
503

504

505
def _get_legal_return_val_var_name(var_dict: dict) -> str:
    """Return a name for the function's return value that is not a key of ``var_dict``.

    The name is ``__function_return_value__`` followed by as few extra underscores as needed to
    avoid every name in ``var_dict``. It is used to refer to the function's return value when
    evaluating postconditions.
    """
    base_name = "__function_return_value__"
    padding = ""
    while base_name + padding in var_dict:
        padding += "_"
    return base_name + padding
517

518

519
def _replace_return_val_assertion(assertion: str, return_val_var_name: Optional[str]) -> str:
    """Substitute the generated return-value variable name for FUNCTION_RETURN_VALUE.

    Every occurrence of FUNCTION_RETURN_VALUE in ``assertion`` is replaced. When it does not
    occur, ``assertion`` is returned unchanged.

    Precondition: If FUNCTION_RETURN_VALUE is in assertion, then return_val_var_name is not None
    """
    if FUNCTION_RETURN_VALUE not in assertion:
        return assertion
    return assertion.replace(FUNCTION_RETURN_VALUE, return_val_var_name)
530

531

532
def _check_assertions(
    wrapped: Callable[..., Any],
    function_locals: dict,
    condition_type: str = "precondition",
    function_return_val: Any = None,
) -> None:
    """Check that the given assertions are still satisfied.

    ``condition_type`` selects which cached, compiled conditions to evaluate:
        - "precondition": ``__preconditions__``;
        - "postcondition": ``__postconditions__``.
    Both are read from the underlying function of ``wrapped``.

    Each condition is evaluated with the function's globals plus ``function_locals``. For
    postconditions, ``function_return_val`` is also bound to the condition's generated
    return-value variable name. A condition whose evaluation raises any exception other than
    AssertionError is skipped, with a debug message.

    Raises:
        PyTAContractError: if a condition evaluates to a falsy value.
        AssertionError: if evaluating a condition raises AssertionError.
    """
    # Check bounded function: contracts are cached on the underlying function
    if hasattr(wrapped, "__self__"):
        target = wrapped.__func__
    else:
        target = wrapped
    assertions = []
    if condition_type == "precondition":
        assertions = target.__preconditions__
    elif condition_type == "postcondition":
        assertions = target.__postconditions__
    for assertion_str, compiled, *return_val_var_name in assertions:
        return_val_dict = {}
        if condition_type == "postcondition":
            return_val_dict = {return_val_var_name[0]: function_return_val}
        try:
            _debug(f"Checking {condition_type} for {wrapped.__qualname__}: {assertion_str}")
            check = eval(compiled, {**wrapped.__globals__, **function_locals, **return_val_dict})
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        except Exception:
            # Best effort: a condition that cannot be evaluated is reported and skipped.
            # (A bare ``except:`` here also swallowed KeyboardInterrupt/SystemExit.)
            _debug(f"Warning: could not evaluate {condition_type}: {assertion_str}")
        else:
            if not check:
                arg_string = ", ".join(
                    f"{k}: {_display_value(v)}" for k, v in function_locals.items()
                )
                arg_string = "{" + arg_string + "}"

                return_val_string = ""

                if condition_type == "postcondition":
                    return_val_string = f"and return value {function_return_val}"
                raise PyTAContractError(
                    f'{wrapped.__name__} {condition_type} "{assertion_str}" was '
                    f"violated for arguments {arg_string} {return_val_string}"
                )
575

576

577
def parse_assertions(obj: Any, parse_token: str = "Precondition") -> list[str]:
    """Return a list of preconditions/postconditions/representation invariants parsed from the given entity's docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.
    The docstring is read from ``obj.doc_node`` when obj is an astroid node, and from
    ``obj.__doc__`` otherwise.

    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    The form's header is matched case-insensitively. Only the first line in either form is
    used. Lines that merely start with parse_token (e.g. "Preconditions are listed below.")
    are ignored.
    """
    if hasattr(obj, "doc_node") and obj.doc_node is not None:
        # Check if obj is an astroid node
        docstring = obj.doc_node.value
    else:
        docstring = getattr(obj, "__doc__") or ""
    lines = [line.strip() for line in docstring.split("\n")]

    token = parse_token.lower()
    single_prefix = token + ":"
    group_prefix = token + "s:"

    for index, line in enumerate(lines):
        lowered = line.lower()
        if lowered.startswith(single_prefix):
            # Form 1: the condition follows the header on the same line
            return [line[len(single_prefix) :].strip()]
        if lowered.startswith(group_prefix):
            # Form 2: collect "- <cond>" lines until the first other non-blank line
            assertions = []
            for item_line in lines[index + 1 :]:
                if item_line.startswith("-"):
                    assertion = item_line[1:].strip()
                    if hasattr(obj, "__qualname__"):
                        _debug(f"Adding assertion to {obj.__qualname__}: {assertion}")
                    assertions.append(assertion)
                elif item_line != "":
                    break
            return assertions

    return []
619

620

621
def _display_value(value: Any, max_length: int = _DEFAULT_MAX_VALUE_LENGTH) -> str:
    """Return a human-friendly representation (the repr) of the given value.

    Unless DEBUG_CONTRACTS is True, a repr longer than max_length characters is shortened.
    Equal-length pieces are kept from its start and end, joined by "...".

    Preconditions:
        - max_length >= 5
    """
    text = repr(value)
    if DEBUG_CONTRACTS or len(text) <= max_length:
        return text
    keep = (max_length - 3) // 2
    return f"{text[:keep]}...{text[-keep:]}"
635

636

637
def _display_annotation(annotation: Any) -> str:
    """Return a human-friendly representation of the given type annotation.

    ``NoneType`` is shown as ``None``. Generic aliases are shown by their repr. Other objects
    with a ``__name__`` are shown by that name, and anything else by its repr.

    >>> _display_annotation(int)
    'int'
    >>> _display_annotation(list[int])
    'list[int]'
    >>> from typing import List
    >>> _display_annotation(List[int])
    'typing.List[int]'
    """
    if annotation is type(None):
        return "None"  # Use 'None' instead of 'NoneType'
    is_generic = hasattr(annotation, "__origin__")
    if not is_generic and hasattr(annotation, "__name__"):
        return annotation.__name__
    return repr(annotation)
656

657

658
def _get_module(obj: Any) -> ModuleType:
    """Return the module where obj was defined (normally obj.__module__).

    NOTE: this function defines a special case when using PyCharm and the file
    defining the object is "Run in Python Console". In this case, the pydevd runner
    renames the '__main__' module to 'pydev_umd', and so we need to access that
    module instead. This behaviour can be disabled by setting RENAME_MAIN_TO_PYDEV_UMD
    to False.
    """
    module_name = obj.__module__
    module = sys.modules[module_name]

    pydev_workaround_applies = (
        module_name == "__main__"
        and RENAME_MAIN_TO_PYDEV_UMD
        and _PYDEV_UMD_NAME in sys.modules
    )
    if not pydev_workaround_applies:
        return module

    # Only functions and classes have a name we can look up; be conservative otherwise
    if not isinstance(obj, (FunctionType, type)):
        return module

    # Prefer '__main__' if it actually defines obj, otherwise use PyCharm's renamed module
    return module if obj.__name__ in vars(module) else sys.modules[_PYDEV_UMD_NAME]
688

689

690
def _debug(msg: str) -> None:
    """Log ``msg`` at DEBUG level when DEBUG_CONTRACTS is True; otherwise do nothing."""
    if DEBUG_CONTRACTS:
        # basicConfig does nothing if the root logger already has handlers configured.
        logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.DEBUG)
        logging.debug(msg)
699

700

701
def _set_invariants(klass: type) -> None:
    """Retrieve and set the representation invariants of this class.

    Invariants are gathered from ``klass`` and its superclasses, from the most general class to
    ``klass`` itself:
        - a class that already has its own ``__representation_invariants__`` contributes that
          stored list;
        - every other non-builtin class has its docstring parsed, and invariants that cannot be
          compiled are skipped.
    Each invariant (identified by its source text) is kept only once.

    The resulting list of (assertion, compiled) pairs is stored as
    ``klass.__representation_invariants__``.
    """
    rep_invariants: list[tuple[str, CodeType]] = []
    seen: set[str] = set()

    def add(assertion: str, compiled: CodeType) -> None:
        # A decorated superclass's stored list already includes the invariants of its own
        # superclasses, which were visited earlier in the MRO; skip those duplicates.
        if assertion not in seen:
            seen.add(assertion)
            rep_invariants.append((assertion, compiled))

    # Iterate over all inherited classes except builtins
    for cls in reversed(klass.__mro__):
        if "__representation_invariants__" in cls.__dict__:
            for assertion, compiled in cls.__representation_invariants__:
                add(assertion, compiled)
        elif cls.__module__ != "builtins":
            assertions = parse_assertions(cls, parse_token="Representation Invariant")
            for assertion in assertions:
                try:
                    compiled = compile(assertion, "<string>", "eval")
                except Exception:
                    _debug(
                        f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
                    )
                    continue
                add(assertion, compiled)

    setattr(klass, "__representation_invariants__", rep_invariants)
724

725

726
def validate_invariants(obj: object) -> None:
    """Check that the representation invariants of obj's class hold for obj.

    Raises:
        AssertionError: if a representation invariant is violated.
    """
    cls = obj.__class__
    global_scope = vars(_get_module(cls))

    try:
        _check_invariants(obj, cls, global_scope)
    except PyTAContractError as error:
        raise AssertionError(str(error)) from None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc