• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

LeanderCS / flask-inputfilter / #419

02 Jul 2025 08:04PM UTC coverage: 94.487% (-1.3%) from 95.792%
#419

Pull #60

coveralls-python

LeanderCS
Move complex logic outside of base InputFilter class
Pull Request #60: Optimize

292 of 328 new or added lines in 108 files covered. (89.02%)

10 existing lines in 2 files now uncovered.

1868 of 1977 relevant lines covered (94.49%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.21
/flask_inputfilter/validators/is_dataclass_validator.py
1
from __future__ import annotations
1✔
2

3
import dataclasses
1✔
4
from typing import Any, ClassVar, Optional, Type, TypeVar, Union, _GenericAlias
1✔
5

6
from flask_inputfilter.exceptions import ValidationError
1✔
7
from flask_inputfilter.models import BaseValidator
1✔
8

9
T = TypeVar("T")
1✔
10

11

12
# TODO: Replace with typing.get_origin when Python 3.7 support is dropped.
13
def _is_types_instance(tp: Any, type_name: str) -> bool:
    """
    Tell whether ``tp`` is an instance of ``types.<type_name>``.

    ``types.GenericAlias`` (Python 3.9+) and ``types.UnionType``
    (Python 3.10+) do not exist on older interpreters, so the class is
    looked up by name and a missing attribute simply yields ``False``.
    """
    # Imported locally so this helper does not depend on a module-level
    # ``import types``.
    import types

    candidate = getattr(types, type_name, None)
    return candidate is not None and isinstance(tp, candidate)


def get_origin(tp: Any) -> Optional[Type[Any]]:
    """
    Get the unsubscripted version of a type.

    Supported inputs:

    - ``typing`` aliases such as ``List[int]`` or ``Optional[int]``
      (their ``__origin__`` is returned, e.g. ``list`` or ``Union``),
    - builtin generics such as ``list[int]`` (PEP 585),
    - ``X | Y`` unions (PEP 604), which are reported as ``typing.Union``
      so that they are handled exactly like ``Union[X, Y]``.

    Returns ``None`` for anything that is not a parameterized type.
    """
    if isinstance(tp, _GenericAlias):
        return tp.__origin__
    # Checked before GenericAlias so that every union flavour maps to the
    # single ``Union`` marker callers dispatch on.
    if _is_types_instance(tp, "UnionType"):
        return Union
    if _is_types_instance(tp, "GenericAlias"):
        return tp.__origin__
    return None


# TODO: Replace with typing.get_args when Python 3.7 support is dropped.
def get_args(tp: Any) -> tuple[Any, ...]:
    """
    Get type arguments with all substitutions performed.

    For unions, basic types, and special typing forms, returns the type
    arguments. For example, for list[int] returns (int,).

    Handles the same inputs as ``get_origin``: ``typing`` aliases,
    builtin generics (``list[int]``) and ``X | Y`` unions. Returns an
    empty tuple for anything that is not a parameterized type.
    """
    if isinstance(tp, _GenericAlias):
        return tp.__args__
    if _is_types_instance(tp, "GenericAlias") or _is_types_instance(
        tp, "UnionType"
    ):
        return tp.__args__
    return ()
×
36

37

38
class IsDataclassValidator(BaseValidator):
    """
    Validates that the provided value conforms to a specific dataclass type.

    **Parameters:**

    - **dataclass_type** (*Type[T]*): The expected dataclass type.
    - **error_message** (*Optional[str]*): Custom error message if
      validation fails.

    **Expected Behavior:**

    Ensures the input is a dictionary and that every field without a
    default is present. Raises a ``ValidationError`` if the structure does
    not match. All fields in the dataclass are validated against their
    types, including nested dataclasses, lists, dictionaries and unions;
    members of lists, dictionaries and unions are validated recursively,
    so ``Optional[List[int]]`` or ``List[SomeDataclass]`` are supported.
    Fields annotated with ``Any`` accept every value.

    **Example Usage:**

    .. code-block:: python

        from dataclasses import dataclass

        @dataclass
        class User:
            id: int
            name: str

        class UserInputFilter(InputFilter):
            def __init__(self):
                super().__init__()

                self.add('user', validators=[
                    IsDataclassValidator(dataclass_type=User)
                ])
    """

    __slots__ = ("dataclass_type", "error_message")

    _ERROR_TEMPLATES: ClassVar = {
        "not_dict": "The provided value is not a dict instance.",
        "not_dataclass": "'{dataclass_type}' is not a valid dataclass.",
        "missing_field": "Missing required field '{field_name}' in value "
        "'{value}'.",
        "type_mismatch": "Field '{field_name}' in value '{value}' is not of "
        "type '{expected_type}'.",
        "list_type": "Field '{field_name}' in value '{value}' is not a valid "
        "list of '{item_type}'.",
        "list_item": "Item at index {index} in field '{field_name}' is not "
        "of type '{expected_type}'.",
        "dict_type": "Field '{field_name}' in value '{value}' is not a valid "
        "dict with keys of type '{key_type}' and values of type "
        "'{value_type}'.",
        "dict_key": "Key '{key}' in field '{field_name}' is not of type "
        "'{expected_type}'.",
        "dict_value": "Value for key '{key}' in field '{field_name}' is not "
        "of type '{expected_type}'.",
        "union_mismatch": "Field '{field_name}' in value '{value}' does not "
        "match any of the types: {types}.",
        "unsupported_type": "Unsupported type '{field_type}' for field "
        "'{field_name}'.",
    }

    def __init__(
        self,
        dataclass_type: Type[T],
        error_message: Optional[str] = None,
    ) -> None:
        """
        Store the configuration and reject non-dataclass types early.

        Raises ``ValueError`` if ``dataclass_type`` is not a dataclass.
        """
        self.dataclass_type = dataclass_type
        self.error_message = error_message

        if not dataclasses.is_dataclass(self.dataclass_type):
            raise ValueError(
                self._format_error(
                    "not_dataclass", dataclass_type=self.dataclass_type
                )
            )

    def _format_error(self, error_type: str, **kwargs) -> str:
        """
        Format error message using template or custom message.

        A non-empty ``error_message`` always wins; otherwise the template
        registered for ``error_type`` is filled with ``kwargs``.
        """
        if self.error_message:
            return self.error_message

        template = self._ERROR_TEMPLATES.get(error_type, "Validation error")
        return template.format(**kwargs)

    def validate(self, value: Any) -> None:
        """
        Validate that value conforms to the dataclass type.

        Raises ``ValidationError`` if ``value`` is not a dict or if any
        field is missing or has the wrong type.
        """
        self._validate_is_dict(value)

        for field in dataclasses.fields(self.dataclass_type):
            self._validate_field(field, value)

    def _validate_is_dict(self, value: Any) -> None:
        """Ensure value is a dictionary."""
        if not isinstance(value, dict):
            raise ValidationError(self._format_error("not_dict"))

    def _validate_field(
        self, field: dataclasses.Field, value: dict[str, Any]
    ) -> None:
        """
        Validate a single field of the dataclass.

        A missing key is only an error when the field has no default.
        """
        field_name = field.name

        if field_name not in value:
            if not IsDataclassValidator._has_default(field):
                raise ValidationError(
                    self._format_error(
                        "missing_field", field_name=field_name, value=value
                    )
                )
            return

        self._validate_field_type(
            field_name, value[field_name], field.type, value
        )

    @staticmethod
    def _has_default(field: dataclasses.Field) -> bool:
        """Check if a field has a default value or a default factory."""
        return (
            field.default is not dataclasses.MISSING
            or field.default_factory is not dataclasses.MISSING
        )

    def _validate_field_type(
        self,
        field_name: str,
        field_value: Any,
        field_type: Type,
        parent_value: dict[str, Any],
    ) -> None:
        """
        Validate that a field value matches its expected type.

        Dispatches to the generic, nested-dataclass or simple-type check
        depending on ``field_type``.
        """
        # ``Any`` accepts every value and cannot be passed to isinstance().
        if field_type is Any:
            return

        origin = get_origin(field_type)

        if origin is not None:
            self._validate_generic_type(
                field_name, field_value, field_type, origin, parent_value
            )
        elif dataclasses.is_dataclass(field_type):
            # Hand the custom message down so nested failures report it too.
            IsDataclassValidator._validate_nested_dataclass(
                field_value, field_type, self.error_message
            )
        else:
            self._validate_simple_type(
                field_name, field_value, field_type, parent_value
            )

    def _matches_type(
        self,
        field_name: str,
        candidate: Any,
        expected_type: Any,
        parent_value: dict[str, Any],
    ) -> bool:
        """
        Tell whether ``candidate`` conforms to ``expected_type``.

        Used for members of lists, dicts and unions. The check recurses
        through ``_validate_field_type`` so parameterized members and
        nested dataclasses (given as dicts) are validated structurally
        instead of being passed to ``isinstance``.
        """
        # Instances of a dataclass used to pass the plain isinstance()
        # check performed on container members; keep accepting them.
        if (
            dataclasses.is_dataclass(expected_type)
            and isinstance(expected_type, type)
            and isinstance(candidate, expected_type)
        ):
            return True

        try:
            self._validate_field_type(
                field_name, candidate, expected_type, parent_value
            )
        except ValidationError:
            return False
        return True

    def _validate_generic_type(
        self,
        field_name: str,
        field_value: Any,
        field_type: Type,
        origin: Type,
        parent_value: dict[str, Any],
    ) -> None:
        """
        Validate generic types like list[T], dict[K, V], Optional[T].

        Raises ``ValidationError`` for any other generic origin.
        """
        handlers = {
            list: self._validate_list_type,
            dict: self._validate_dict_type,
            Union: self._validate_union_type,
        }

        handler = handlers.get(origin)
        if handler is None:
            raise ValidationError(
                self._format_error(
                    "unsupported_type",
                    field_type=field_type,
                    field_name=field_name,
                )
            )

        handler(field_name, field_value, get_args(field_type), parent_value)

    def _validate_list_type(
        self,
        field_name: str,
        field_value: Any,
        args: tuple[Type, ...],
        parent_value: dict[str, Any],
    ) -> None:
        """Validate list[T] type, checking every item against T."""
        # Without type arguments there is nothing to check per item.
        item_type = args[0] if args else Any

        if not isinstance(field_value, list):
            raise ValidationError(
                self._format_error(
                    "list_type",
                    field_name=field_name,
                    value=parent_value,
                    item_type=item_type,
                )
            )

        for index, item in enumerate(field_value):
            if not self._matches_type(
                field_name, item, item_type, parent_value
            ):
                raise ValidationError(
                    self._format_error(
                        "list_item",
                        index=index,
                        field_name=field_name,
                        expected_type=item_type,
                    )
                )

    def _validate_dict_type(
        self,
        field_name: str,
        field_value: Any,
        args: tuple[Type, ...],
        parent_value: dict[str, Any],
    ) -> None:
        """Validate dict[K, V] type, checking every key and value."""
        # Without both type arguments there is nothing to check per entry.
        key_type, value_type = args if len(args) == 2 else (Any, Any)

        if not isinstance(field_value, dict):
            raise ValidationError(
                self._format_error(
                    "dict_type",
                    field_name=field_name,
                    value=parent_value,
                    key_type=key_type,
                    value_type=value_type,
                )
            )

        for key, item in field_value.items():
            if not self._matches_type(
                field_name, key, key_type, parent_value
            ):
                raise ValidationError(
                    self._format_error(
                        "dict_key",
                        key=key,
                        field_name=field_name,
                        expected_type=key_type,
                    )
                )
            if not self._matches_type(
                field_name, item, value_type, parent_value
            ):
                raise ValidationError(
                    self._format_error(
                        "dict_value",
                        key=key,
                        field_name=field_name,
                        expected_type=value_type,
                    )
                )

    def _validate_union_type(
        self,
        field_name: str,
        field_value: Any,
        args: tuple[Type, ...],
        parent_value: dict[str, Any],
    ) -> None:
        """
        Validate Union types, particularly Optional[T].

        The value must conform to at least one member of the union. For
        ``Optional[T]`` a mismatch is reported as a plain type mismatch
        against ``T``; other unions list every allowed type.
        """
        # Optional[T] stores NoneType (not None) in its type arguments.
        none_type = type(None)
        allows_none = none_type in args

        if field_value is None and allows_none:
            return

        candidates = [t for t in args if t is not none_type]
        if any(
            self._matches_type(field_name, field_value, t, parent_value)
            for t in candidates
        ):
            return

        if allows_none and len(candidates) == 1:
            raise ValidationError(
                self._format_error(
                    "type_mismatch",
                    field_name=field_name,
                    value=parent_value,
                    expected_type=candidates[0],
                )
            )

        raise ValidationError(
            self._format_error(
                "union_mismatch",
                field_name=field_name,
                value=parent_value,
                types=", ".join(str(t) for t in args),
            )
        )

    @staticmethod
    def _validate_nested_dataclass(
        field_value: Any,
        field_type: Type,
        error_message: Optional[str] = None,
    ) -> None:
        """
        Validate nested dataclass.

        ``error_message`` is forwarded to the nested validator so a custom
        message configured on the outer validator is used for nested
        failures as well.
        """
        IsDataclassValidator(field_type, error_message).validate(field_value)

    def _validate_simple_type(
        self,
        field_name: str,
        field_value: Any,
        field_type: Type,
        parent_value: dict[str, Any],
    ) -> None:
        """
        Validate simple types like int, str, bool, etc.

        Annotations that cannot be used with ``isinstance`` are reported
        as an unsupported type instead of surfacing a ``TypeError``.
        """
        try:
            matches = isinstance(field_value, field_type)
        except TypeError:
            raise ValidationError(
                self._format_error(
                    "unsupported_type",
                    field_type=field_type,
                    field_name=field_name,
                )
            ) from None

        if not matches:
            raise ValidationError(
                self._format_error(
                    "type_mismatch",
                    field_name=field_name,
                    value=parent_value,
                    expected_type=field_type,
                )
            )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc