• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

p2p-ld / numpydantic / 11289820347

11 Oct 2024 09:23AM UTC coverage: 98.351% (-0.4%) from 98.757%
11289820347

Pull #31

github

web-flow
Merge 189a3e791 into 69dbe3955
Pull Request #31: [tests] `numpydantic.testing` - exposing helpers for 3rd-party interface development & combinatoric testing

313 of 317 new or added lines in 10 files covered. (98.74%)

9 existing lines in 3 files now uncovered.

1491 of 1516 relevant lines covered (98.35%)

9.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.56
/src/numpydantic/interface/interface.py
1
"""
2
Base Interface metaclass
3
"""
4

5
import inspect
10✔
6
import warnings
10✔
7
from abc import ABC, abstractmethod
10✔
8
from functools import lru_cache
10✔
9
from importlib.metadata import PackageNotFoundError, version
10✔
10
from operator import attrgetter
10✔
11
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
10✔
12

13
import numpy as np
10✔
14
from pydantic import BaseModel, SerializationInfo, ValidationError
10✔
15

16
from numpydantic.exceptions import (
10✔
17
    DtypeError,
18
    MarkMismatchError,
19
    NoMatchError,
20
    ShapeError,
21
    TooManyMatchesError,
22
)
23
from numpydantic.types import DtypeType, NDArrayType, ShapeType
10✔
24
from numpydantic.validation import validate_dtype, validate_shape
10✔
25

26
T = TypeVar("T", bound=NDArrayType)  # array type an Interface returns (see Interface.return_type)
10✔
27
U = TypeVar("U", bound="JsonDict")  # the JsonDict subclass that handle_input is called on
10✔
28
V = TypeVar("V")  # input type
10✔
29
W = TypeVar("W")  # Any type in handle_input
10✔
30

31

32
class InterfaceMark(BaseModel):
    """
    JSON-serializable marker recording which interface produced a dump,
    so that JSON dumps can be round-tripped.
    """

    module: str
    cls: str
    name: str
    version: str

    def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool:
        """
        Compare this mark against the mark generated by an interface class.

        Args:
            cls (Type): Interface class whose ``mark_interface()`` result is
                compared against this mark
            raise_on_error (bool): If ``True``, raise a ``MarkMismatchError``
                on a mismatch instead of returning ``False``

        Returns:
            bool: ``True`` when both marks are equal, ``False`` otherwise

        Raises:
            :class:`.MarkMismatchError` on a mismatch, but only when
            ``raise_on_error`` is ``True``
        """
        if self == cls.mark_interface():
            return True

        if raise_on_error:
            raise MarkMismatchError(
                "Mismatch between serialized mark and current interface, "
                f"Serialized: {self}; current: {cls}"
            )
        return False

    def match_by_name(self) -> Optional[Type["Interface"]]:
        """
        Look up the first enabled interface (in definition order, unsorted)
        whose ``name`` equals the name stored in this mark.

        Returns:
            The matching interface class, or ``None`` when no interface
            carries that name.
        """
        same_name = (
            iface for iface in Interface.interfaces(sort=False) if iface.name == self.name
        )
        return next(same_name, None)
10✔
74

75

76
class JsonDict(BaseModel):
    """
    Model of an array as it is dumped when ``round_trip == True``.

    .. admonition:: Developer's Note

        The field of a JsonDict that holds the actual array data should be
        called ``value`` (not ``array`` or anything else), and no other
        field should use the name ``value`` .

        Once serialized to JSON, a list of array data can't be told apart
        from a list of metadata. We want to keep lists of metadata possible
        for now, and we also want to avoid walking every element of an array
        in context parameter transformations such as
        relativizing/absolutizing paths. Reserving a single, agreed-upon
        name -- ``value`` -- for the array data makes that possible.

    """

    type: str

    @abstractmethod
    def to_array_input(self) -> V:
        """
        Turn this roundtrip specifier into the matching input object
        (one of the ``input_types`` of an interface).
        """

    @classmethod
    def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
        """
        Test whether a dictionary validates against this JsonDict model.

        Args:
            val (dict): Dictionary to validate
            raise_on_error (bool): If ``True``, let the
                ``ValidationError`` propagate instead of returning
                ``False``. (default: ``False``)

        Returns:
            bool - true if valid, false if not
        """
        try:
            cls.model_validate(val)
        except ValidationError:
            if not raise_on_error:
                return False
            raise
        else:
            return True

    @classmethod
    def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
        """
        Convert the json-serialized roundtrip form of an array
        (from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``)
        into the input format via :meth:`.JsonDict.to_array_input`.

        A ``dict`` is first instantiated as this model, an instance of this
        model is converted directly, and anything else is returned unchanged.
        """
        if isinstance(value, dict):
            return cls(**value).to_array_input()
        if isinstance(value, cls):
            return value.to_array_input()
        return value
10✔
139

140

141
class MarkedJson(BaseModel):
    """
    Model of a JSON dump that carries an additional interface mark,
    as produced when dumping with ``mark_interface: True`` in the
    serialization ``context``
    """

    interface: InterfaceMark
    value: Union[list, dict]
    """
    Inner value of the array. It is not validated as a JsonDict here;
    that should happen downstream from us for performance reasons
    """

    @classmethod
    def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]:
        """
        Cast a dict to MarkedJson when it has both an ``interface`` and a
        ``value`` key and validates; in every other case hand back the input
        unchanged.
        """
        looks_marked = (
            isinstance(value, dict) and "interface" in value and "value" in value
        )
        if not looks_marked:
            return value

        try:
            return MarkedJson(**value)
        except ValidationError:
            # has the right keys, but isn't actually a MarkedJson dict
            return value
10✔
166

167

168
class Interface(ABC, Generic[T]):
    """
    Abstract parent class for interfaces to different array formats

    Subclasses are discovered through ``__subclasses__`` (see
    :meth:`.interfaces`) and matched against inputs with :meth:`.match` .
    """

    input_types: Tuple[Any, ...]
    return_type: Type[T]
    priority: int = 0

    def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
        """
        Args:
            shape: Shape specification to validate arrays against
                (``Any`` skips shape validation, see :meth:`.validate_shape`)
            dtype: Dtype specification to validate arrays against
        """
        self.shape = shape
        self.dtype = dtype

    def validate(self, array: Any) -> T:
        """
        Validate input, returning final array type

        Calls the methods, in order:

        * array = :meth:`.deserialize` (array)
        * array = :meth:`.before_validation` (array)
        * dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
            override if eg. the dtype is not contained in ``array.dtype``
        * valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
            the one in the NDArray specification. Override if special
            validation logic is needed for a given format
        * :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
            raise an exception if it was invalid. Override to implement custom
            exceptions or error conditions, or make validation errors conditional.
        * array = :meth:`.after_validate_dtype` (array) - hook for additional
            validation or array modification mid-validation
        * shape = :meth:`.get_shape` (array) - get the shape from the array,
            override if eg. the shape is not contained in ``array.shape``
        * valid = :meth:`.validate_shape` (shape) - check that the shape matches
            the one in the NDArray specification. Override if special validation
            logic is needed.
        * :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
            raise an exception if it was invalid, in the same way as
            :meth:`.raise_for_dtype` .
        * array = :meth:`.after_validation` (array) - hook after validation for
            modifying the array that is set as the model field value

        Follow the method signatures and return types to override.

        Implementing an interface subclass largely consists of overriding these methods
        as needed.

        Raises:
            If validation fails, rather than eg. returning ``False``, exceptions will
            be raised (to halt the rest of the pydantic validation process).
            When using interfaces outside of pydantic, you must catch both
            :class:`.DtypeError` and :class:`.ShapeError` (both of which are children
            of :class:`.InterfaceError` )
        """
        array = self.deserialize(array)

        array = self.before_validation(array)

        dtype = self.get_dtype(array)
        dtype_valid = self.validate_dtype(dtype)
        self.raise_for_dtype(dtype_valid, dtype)
        array = self.after_validate_dtype(array)

        shape = self.get_shape(array)
        shape_valid = self.validate_shape(shape)
        self.raise_for_shape(shape_valid, shape)

        array = self.after_validation(array)

        return array

    def deserialize(self, array: Any) -> Union[V, Any]:
        """
        If given a JSON serialized version of the array,
        deserialize it first.

        If a roundtrip-serialized :class:`.JsonDict`,
        pass to :meth:`.JsonDict.handle_input`.

        If a roundtrip-serialized :class:`.MarkedJson`,
        unpack mark, check for validity, warn if not,
        and try to continue with validation
        """
        if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
            try:
                marked_array.interface.is_valid(self.__class__, raise_on_error=True)
            except MarkMismatchError as e:
                # a mismatched mark is not fatal: warn and keep validating
                # the unwrapped value with this interface
                warnings.warn(
                    str(e) + "\nAttempting to continue validation...", stacklevel=2
                )
            array = marked_array.value

        return self.json_model.handle_input(array)

    def before_validation(self, array: Any) -> NDArrayType:
        """
        Optional step pre-validation that coerces the input into a type that can be
        validated for shape and dtype

        Default method is a no-op
        """
        return array

    def get_dtype(self, array: NDArrayType) -> DtypeType:
        """
        Get the dtype from the input array

        When ``array.dtype.type`` is ``np.object_`` , the dtype is taken from
        the contained objects via :meth:`.get_object_dtype` instead.
        """
        if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
            return self.get_object_dtype(array)
        else:
            return array.dtype

    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
        """
        When an array contains an object, get the dtype of the object contained
        by the array.

        Only the first element of the flattened array is inspected.
        """
        return type(array.ravel()[0])

    def validate_dtype(self, dtype: DtypeType) -> bool:
        """
        Validate the dtype of the given array, returning
        ``True`` if valid, ``False`` if not.
        """
        return validate_dtype(dtype, self.dtype)

    def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
        """
        After validating, raise an exception if invalid

        Args:
            valid (bool): Result of :meth:`.validate_dtype`
            dtype: The dtype that was checked, used in the error message

        Raises:
            :class:`~numpydantic.exceptions.DtypeError`
        """
        if not valid:
            raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")

    def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
        """
        Hook to modify array after validating dtype.
        Default is a no-op.
        """
        return array

    def get_shape(self, array: NDArrayType) -> Tuple[int, ...]:
        """
        Get the shape from the array as a tuple of integers
        """
        return array.shape

    def validate_shape(self, shape: Tuple[int, ...]) -> bool:
        """
        Validate the shape of the given array against the shape
        specifier, returning ``True`` if valid, ``False`` if not.

        A shape specifier of ``Any`` accepts every shape.
        """
        if self.shape is Any:
            return True

        return validate_shape(shape, self.shape)

    def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
        """
        Raise a ShapeError if the shape is invalid.

        Args:
            valid (bool): Result of :meth:`.validate_shape`
            shape: The shape that was checked, used in the error message

        Raises:
            :class:`~numpydantic.exceptions.ShapeError`
        """
        if not valid:
            raise ShapeError(
                f"Invalid shape! expected shape {self.shape.prepared_args}, "
                f"got shape {shape}"
            )

    def after_validation(self, array: NDArrayType) -> T:
        """
        Optional step post-validation that coerces the intermediate array type into the
        return type

        Default method is a no-op
        """
        return array

    @classmethod
    @abstractmethod
    def check(cls, array: Any) -> bool:
        """
        Method to check whether a given input applies to this interface
        """

    @classmethod
    @abstractmethod
    def enabled(cls) -> bool:
        """
        Check whether this array interface can be used (eg. its dependent packages are
        installed, etc.)
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Short name for this interface
        """

    @property
    @abstractmethod
    def json_model(self) -> JsonDict:
        """
        The :class:`.JsonDict` model used for roundtripping
        JSON serialization
        """

    @classmethod
    @abstractmethod
    def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
        """
        Convert an array of :attr:`.return_type` to a JSON-compatible format using
        base python types
        """

    @classmethod
    def mark_json(cls, array: Union[list, dict]) -> dict:
        """
        When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
        add additional annotations that would allow the serialized array to be
        roundtripped.

        Default is just to add an :class:`.InterfaceMark`

        Returns:
            dict: ``{"interface": <InterfaceMark>, "value": array}``

        Examples:

            Illustrative shape of the mark once dumped to JSON (the actual
            values depend on the interface class and the installed package
            version)::

                {'interface': {'cls': '<interface class name>',
                               'module': '<module defining the interface>',
                               'name': '<interface short name>',
                               'version': '<package version>'},
                 'value': [1.0, 2.0]}
        """
        return {"interface": cls.mark_interface(), "value": array}

    @classmethod
    def interfaces(
        cls, with_disabled: bool = False, sort: bool = True
    ) -> Tuple[Type["Interface"], ...]:
        """
        Enabled interface subclasses

        Args:
            with_disabled (bool): If ``True`` , get every known interface.
                If ``False`` (default), get only enabled interfaces.
            sort (bool): If ``True`` (default), sort interfaces by priority
                (highest first). If ``False`` , sorted by definition order.
                Used for recursion: we only want to sort once at the top level.

        Returns:
            Tuple of interface classes, each appearing once per inheritance
            path from ``cls`` .
        """
        # get recursively
        subclasses = []
        for i in cls.__subclasses__():
            # a single combined condition so that an enabled interface is
            # added exactly once, even when ``with_disabled`` is requested
            if with_disabled or i.enabled():
                subclasses.append(i)

            subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))

        if sort:
            subclasses = sorted(
                subclasses,
                key=attrgetter("priority"),
                reverse=True,
            )

        return tuple(subclasses)

    @classmethod
    def return_types(cls) -> Tuple[NDArrayType, ...]:
        """Return types for all enabled interfaces"""
        return tuple(i.return_type for i in cls.interfaces())

    @classmethod
    def input_types(cls) -> Tuple[Any, ...]:
        """Input types for all enabled interfaces"""
        in_types = []
        for iface in cls.interfaces():
            # an interface may declare a sequence of input types or a single one
            if isinstance(iface.input_types, (tuple, list)):
                in_types.extend(iface.input_types)
            else:  # pragma: no cover
                in_types.append(iface.input_types)

        return tuple(in_types)

    @classmethod
    def match_mark(cls, array: Any) -> Optional[Type["Interface"]]:
        """
        Match a marked JSON dump of this array to the interface that it indicates.

        First find an interface that matches by name, and then run its
        ``check`` method, because arrays can be dumped with a mark
        but without ``round_trip == True`` (and thus can't necessarily
        use the same interface that they were dumped with)

        Returns:
            Interface if match found, None otherwise
        """
        mark = MarkedJson.try_cast(array)
        if not isinstance(mark, MarkedJson):
            return None

        interface = mark.interface.match_by_name()
        if interface is not None and interface.check(mark.value):
            return interface
        return None

    @classmethod
    def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
        """
        Find the interface that should be used for this array based on its input type

        First runs the ``check`` method for all interfaces returned by
        :meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
        and if no match is found then try the numpy interface. This is because
        :meth:`.NumpyInterface.check` can be expensive, as we could potentially
        try to load the whole array into memory.

        Args:
            array: The input to find an interface for. A marked JSON dump
                is matched via :meth:`.match_mark` first.
            fast (bool): if ``False`` , check all interfaces and raise exceptions for
              having multiple matching interfaces (default). If ``True`` ,
              check each interface (as ordered by its ``priority`` , decreasing),
              and return on the first match.

        Returns:
            The matching interface class

        Raises:
            :class:`.TooManyMatchesError` if more than one non-numpy interface
            matches (only when ``fast`` is ``False`` ),
            :class:`.NoMatchError` if no interface matches.
        """
        # Shortcircuit match if this is a marked json dump
        array = MarkedJson.try_cast(array)
        if (match := cls.match_mark(array)) is not None:
            return match
        elif isinstance(array, MarkedJson):
            # the mark did not resolve to a usable interface:
            # fall back to matching the unwrapped value
            array = array.value

        # first try and find a non-numpy interface, since the numpy interface
        # will try and load the array into memory in its check method
        interfaces = cls.interfaces()
        non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
        # resolved lazily-tolerant: a missing numpy interface must not break
        # matching when another interface applies
        np_interface = next((i for i in interfaces if i.name == "numpy"), None)

        if fast:
            matches = []
            for i in non_np_interfaces:
                if i.check(array):
                    return i
        else:
            matches = [i for i in non_np_interfaces if i.check(array)]

        if len(matches) > 1:
            msg = f"More than one interface matches input {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)
        elif len(matches) == 1:
            return matches[0]

        # no other interface matched: now try the numpy interface
        if np_interface is not None and np_interface.check(array):
            return np_interface
        raise NoMatchError(f"No matching interfaces found for input {array}")

    @classmethod
    def match_output(cls, array: Any) -> Type["Interface"]:
        """
        Find the interface that should be used based on the output type -
        in the case that the output type differs from the input type, eg.
        the HDF5 interface, match an instantiated array for purposes of
        serialization to json, etc.

        Raises:
            :class:`.TooManyMatchesError` if ``array`` is an instance of more
            than one interface's ``return_type`` ,
            :class:`.NoMatchError` if it is an instance of none.
        """
        matches = [i for i in cls.interfaces() if isinstance(array, i.return_type)]
        if len(matches) > 1:
            msg = f"More than one interface matches output {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)
        elif len(matches) == 0:
            raise NoMatchError(f"No matching interfaces found for output {array}")
        else:
            return matches[0]

    @classmethod
    @lru_cache(maxsize=32)
    def mark_interface(cls) -> InterfaceMark:
        """
        Create an interface mark indicating this interface for validation after
        JSON serialization with ``round_trip==True``

        The mark records the name of the module defining the class, the class
        name, the interface ``name`` , and the installed version of the
        module's top-level package. Results are cached per class.
        """
        interface_module = inspect.getmodule(cls)
        interface_module = (
            None if interface_module is None else interface_module.__name__
        )
        try:
            # the distribution is looked up by the top-level package name
            v = (
                None
                if interface_module is None
                else version(interface_module.split(".")[0])
            )
        except (
            PackageNotFoundError
        ):  # pragma: no cover - no tests for missing interface deps
            v = None

        # NOTE(review): ``InterfaceMark.module`` and ``.version`` are declared
        # as ``str`` ; a ``None`` here looks like it would fail model
        # validation rather than produce a mark - confirm intended behavior.
        return InterfaceMark(
            module=interface_module, cls=cls.__name__, name=cls.name, version=v
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc