• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

p2p-ld / numpydantic / 27054202885

06 Jun 2026 05:49AM UTC coverage: 97.114% (-0.7%) from 97.821%
27054202885

Pull #69

github

web-flow
Merge c1c4272d0 into 952a740e0
Pull Request #69: aw shit it's mypy plugin time

376 of 403 new or added lines in 14 files covered. (93.3%)

4 existing lines in 1 file now uncovered.

1918 of 1975 relevant lines covered (97.11%)

6.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.94
/src/numpydantic/interface/interface.py
1
"""
2
Abstract base class for array interfaces, plus the models used for marked and round-trip JSON dumps
3
"""
4

5
import inspect
7✔
6
import warnings
7✔
7
from abc import ABC, abstractmethod
7✔
8
from functools import lru_cache
7✔
9
from importlib.metadata import PackageNotFoundError, version
7✔
10
from operator import attrgetter
7✔
11
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, Union
7✔
12

13
if TYPE_CHECKING:
7✔
NEW
14
    from numpydantic.interface.typing import InterfaceTyping
×
15

16
import numpy as np
7✔
17
from pydantic import BaseModel, SerializationInfo, ValidationError
7✔
18

19
from numpydantic.exceptions import (
7✔
20
    DtypeError,
21
    MarkMismatchError,
22
    NoMatchError,
23
    ShapeError,
24
    TooManyMatchesError,
25
)
26
from numpydantic.types import DtypeType, NDArrayType, ShapeType
7✔
27
from numpydantic.validation import validate_dtype, validate_shape
7✔
28

29
T = TypeVar("T", bound=NDArrayType)
7✔
30
U = TypeVar("U", bound="JsonDict")
7✔
31
V = TypeVar("V")  # input type
7✔
32
W = TypeVar("W")  # Any type in handle_input
7✔
33

34

35
class InterfaceMark(BaseModel):
    """
    JSON-serializable record identifying an interface, stored alongside a
    dumped array so the dump can be matched back to an interface later.
    """

    module: str
    cls: str
    name: str
    version: str

    def is_valid(self, cls: type["Interface"], raise_on_error: bool = False) -> bool:
        """
        Compare this mark with the mark produced by an interface class.

        Args:
            cls (Type): Interface class whose ``mark_interface()`` result is
                compared against this mark
            raise_on_error (bool): If ``True``, raise instead of returning
                ``False`` when the marks differ

        Returns:
            bool: ``True`` if the marks are equal, ``False`` otherwise

        Raises:
            :class:`.MarkMismatchError`: when the marks differ and
                ``raise_on_error`` is ``True``
        """
        if self == cls.mark_interface():
            return True

        if raise_on_error:
            raise MarkMismatchError(
                "Mismatch between serialized mark and current interface, "
                f"Serialized: {self}; current: {cls}"
            )
        return False

    def match_by_name(self) -> type["Interface"] | None:
        """
        Look up the first interface (in unsorted order) whose ``name``
        equals this mark's ``name``.

        Returns:
            The matching interface class, or ``None`` when nothing matches.
        """
        candidates = Interface.interfaces(sort=False)
        return next(
            (iface for iface in candidates if iface.name == self.name),
            None,
        )
7✔
77

78

79
class JsonDict(BaseModel):
    """
    Representation of array when dumped with round_trip == True.

    .. admonition:: Developer's Note

        Any JsonDict that contains an actual array should be named ``value``
        rather than array (or any other name), and nothing but the
        array data should be named ``value`` .

        During JSON serialization, it becomes ambiguous what contains an array
        of data vs. an array of metadata. For the moment we would like to
        reserve the ability to have lists of metadata, so until we rule that out,
        we would like to be able to avoid iterating over every element of an array
        in any context parameter transformation like relativizing/absolutizing paths.
        To avoid that, it's good to agree on a single value name -- ``value`` --
        and avoid using it for anything else.

    """

    type: str

    @abstractmethod
    def to_array_input(self) -> V:
        """
        Convert this roundtrip specifier to the relevant input class
        (one of the ``input_types`` of an interface).
        """

    @classmethod
    def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
        """
        Check whether a given dictionary matches this JsonDict specification

        Args:
            val (dict): The dictionary to check for validity
            raise_on_error (bool): If ``True``, raise the validation error
                rather than returning a bool. (default: ``False``)

        Returns:
            bool - true if valid, false if not

        Raises:
            :class:`pydantic.ValidationError`: if ``val`` does not validate
                and ``raise_on_error`` is ``True``
        """
        try:
            # only the success/failure of validation matters, not the model
            cls.model_validate(val)
        except ValidationError:
            if raise_on_error:
                # bare raise re-raises the active error with its traceback intact
                raise
            return False
        return True

    @classmethod
    def handle_input(cls: type[U], value: dict | U | W) -> V | W:
        """
        Handle input that is the json serialized roundtrip version
        (from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``)
        converting it to the input format with :meth:`.JsonDict.to_array_input`
        or passing it through if not applicable

        Args:
            value: A dict of fields for this model, an instance of this model,
                or anything else

        Returns:
            The result of :meth:`.JsonDict.to_array_input` for dicts and
            instances of this model; ``value`` unchanged otherwise.
        """
        if isinstance(value, dict):
            value = cls(**value).to_array_input()
        elif isinstance(value, cls):
            value = value.to_array_input()
        return value

    @staticmethod
    def reshape_input(value: T, shape: tuple[int, ...]) -> T:
        """
        If a `reshape` value is present on the array, and the array shape doesn't match,
        attempt to reshape it.

        Args:
            value: Array-like object exposing ``shape`` and ``reshape``
            shape (tuple[int, ...]): Shape to reshape ``value`` to

        Returns:
            The reshaped array, or ``value`` unchanged when the shape already
            matches or ``reshape`` raises a ``ValueError`` (a warning is
            emitted in the latter case).
        """
        if value.shape != shape:
            try:
                value = value.reshape(shape)
            except ValueError:
                # Best effort: warn and hand the unreshaped array on so that
                # later validation decides whether it is acceptable.
                warnings.warn(
                    f"Input data has shape {value.shape}, "
                    f"but roundtrip form specifies {shape}, "
                    f"and {value.shape} can't be cast to {shape}. "
                    f"Attempting to proceed with validation without reshaping.",
                    stacklevel=1,
                )
        return value
7✔
161

162

163
class MarkedJson(BaseModel):
    """
    Model of JSON dumped with an additional interface mark
    with ``model_dump_json({'mark_interface': True})``
    """

    interface: InterfaceMark
    value: list | dict
    """
    Inner value of the array, we don't validate for JsonDict here,
    that should be downstream from us for performance reasons
    """

    @classmethod
    def try_cast(cls, value: V | dict) -> Union[V, "MarkedJson"]:
        """
        Cast a dict carrying both ``interface`` and ``value`` keys to a
        :class:`.MarkedJson`; hand anything else back untouched.

        A dict that has both keys but fails model validation is also
        returned unchanged.
        """
        looks_marked = (
            isinstance(value, dict) and "interface" in value and "value" in value
        )
        if not looks_marked:
            return value

        try:
            return MarkedJson(**value)
        except ValidationError:
            # fine, just not a MarkedJson dict even if it looks like one
            return value
7✔
188

189

190
class Interface(ABC, Generic[T]):
    """
    Abstract parent class for interfaces to different array formats
    """

    input_types: tuple[Any, ...]
    return_type: type[T]
    priority: int = 0
    typing: ClassVar[type["InterfaceTyping"] | None] = None
    """
    Optional static-typing companion class used by the mypy plugin and
    the mypy test generator. ``None`` means this interface does not opt
    into static constructor inference.
    """

    def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
        """
        Store the shape and dtype specification to validate arrays against.

        Args:
            shape: Shape specification (default ``Any``)
            dtype: Dtype specification (default ``Any``)
        """
        self.shape = shape
        self.dtype = dtype

    def validate(self, array: Any) -> T:
        """
        Validate input, returning final array type

        Calls the methods, in order:

        * array = :meth:`.deserialize` (array)
        * array = :meth:`.before_validation` (array)
        * dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
            override if eg. the dtype is not contained in ``array.dtype``
        * valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
            the one in the NDArray specification. Override if special
            validation logic is needed for a given format
        * :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
            raise an exception if it was invalid. Override to implement custom
            exceptions or error conditions, or make validation errors conditional.
        * array = :meth:`.after_validate_dtype` (array) - hook for additional
            validation or array modification mid-validation
        * shape = :meth:`.get_shape` (array) - get the shape from the array,
            override if eg. the shape is not contained in ``array.shape``
        * valid = :meth:`.validate_shape` (shape) - check that the shape matches
            the one in the NDArray specification. Override if special validation
            logic is needed.
        * :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
            raise an exception if it was invalid. You know the deal bc it's the same
            as raise for dtype.
        * :meth:`.after_validation` - hook after validation for modifying the array
            that is set as the model field value

        Follow the method signatures and return types to override.

        Implementing an interface subclass largely consists of overriding these methods
        as needed.

        If validation fails, rather than eg. returning ``False``, exceptions will
        be raised (to halt the rest of the pydantic validation process).
        When using interfaces outside of pydantic, you must catch both
        :class:`.DtypeError` and :class:`.ShapeError` (both of which are children
        of :class:`.InterfaceError` )

        Raises:
            :class:`.DtypeError`: Dtype of data doesn't match specification
            :class:`.ShapeError`: Shape of data doesn't match specification

        """
        array = self.deserialize(array)

        array = self.before_validation(array)

        dtype = self.get_dtype(array)
        dtype_valid = self.validate_dtype(dtype)
        self.raise_for_dtype(dtype_valid, dtype)
        array = self.after_validate_dtype(array)

        shape = self.get_shape(array)
        shape_valid = self.validate_shape(shape)
        self.raise_for_shape(shape_valid, shape)

        array = self.after_validation(array)

        return array

    def deserialize(self, array: Any) -> V | Any:
        """
        If given a JSON serialized version of the array,
        deserialize it first.

        If a roundtrip-serialized :class:`.JsonDict`,
        pass to :meth:`.JsonDict.handle_input`.

        If a roundtrip-serialized :class:`.MarkedJson`,
        unpack mark, check for validity, warn if not,
        and try to continue with validation
        """
        if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
            try:
                marked_array.interface.is_valid(self.__class__, raise_on_error=True)
            except MarkMismatchError as e:
                # a mismatched mark is not fatal: warn and validate the inner
                # value with this interface anyway
                warnings.warn(
                    str(e) + "\nAttempting to continue validation...", stacklevel=2
                )
            array = marked_array.value

        return self.json_model.handle_input(array)

    def before_validation(self, array: Any) -> NDArrayType:
        """
        Optional step pre-validation that coerces the input into a type that can be
        validated for shape and dtype

        Default method is a no-op
        """
        return array

    def get_dtype(self, array: NDArrayType) -> DtypeType:
        """
        Get the dtype from the input array.

        When ``array.dtype`` has a ``type`` attribute that is ``np.object_``,
        defer to :meth:`.get_object_dtype`; otherwise return ``array.dtype``.
        """
        if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
            return self.get_object_dtype(array)
        else:
            return array.dtype

    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
        """
        When an array contains an object, get the dtype of the object contained
        by the array.

        The type of the first element of the raveled array is used.

        If this method returns `Any`, the dtype validation passes -
        used for e.g. empty arrays for which the dtype of the array can't be determined
        (since there are no objects).
        """
        try:
            return type(array.ravel()[0])
        except IndexError:
            return Any

    def validate_dtype(self, dtype: DtypeType) -> bool:
        """
        Validate the dtype of the given array, returning
        ``True`` if valid, ``False`` if not.
        """
        return validate_dtype(dtype, self.dtype)

    def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
        """
        After validating, raise an exception if invalid

        Raises:
            :class:`~numpydantic.exceptions.DtypeError`
        """
        if not valid:
            raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")

    def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
        """
        Hook to modify array after validating dtype.
        Default is a no-op.
        """
        return array

    def get_shape(self, array: NDArrayType) -> tuple[int, ...]:
        """
        Get the shape from the array as a tuple of integers
        """
        return array.shape

    def validate_shape(self, shape: tuple[int, ...]) -> bool:
        """
        Validate the shape of the given array against the shape
        specifier, returning ``True`` if valid, ``False`` if not.

        A shape specifier of ``Any`` accepts every shape.
        """
        if self.shape is Any:
            return True

        return validate_shape(shape, self.shape)

    def raise_for_shape(self, valid: bool, shape: tuple[int, ...]) -> None:
        """
        Raise a ShapeError if the shape is invalid.

        Raises:
            :class:`~numpydantic.exceptions.ShapeError`
        """
        if not valid:
            raise ShapeError(
                f"Invalid shape! expected shape {self.shape.prepared_args}, "
                f"got shape {shape}"
            )

    def after_validation(self, array: NDArrayType) -> T:
        """
        Optional step post-validation that coerces the intermediate array type into the
        return type

        Default method is a no-op
        """
        return array

    @classmethod
    @abstractmethod
    def check(cls, array: Any) -> bool:
        """
        Method to check whether a given input applies to this interface
        """

    @classmethod
    @abstractmethod
    def enabled(cls) -> bool:
        """
        Check whether this array interface can be used (eg. its dependent packages are
        installed, etc.)
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Short name for this interface
        """

    @property
    @abstractmethod
    def json_model(self) -> JsonDict:
        """
        The :class:`.JsonDict` model used for roundtripping
        JSON serialization
        """

    @classmethod
    @abstractmethod
    def to_json(cls, array: type[T], info: SerializationInfo) -> list | JsonDict:
        """
        Convert an array of :attr:`.Interface.return_type` to a JSON-compatible format
        using base python types
        """

    @classmethod
    def mark_json(cls, array: list | dict) -> dict:
        """
        When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
        add additional annotations that would allow the serialized array to be
        roundtripped.

        Default is just to add an :class:`.InterfaceMark`: the returned dict
        holds the result of :meth:`.mark_interface` under ``interface`` and
        the array under ``value``. The example below is illustrative; the mark
        also carries the interface ``name``.

        Examples:

            >>> from pprint import pprint
            >>> pprint(Interface.mark_json([1.0, 2.0]))
            {'interface': {'cls': 'Interface',
                           'module': 'numpydantic.interface.interface',
                           'version': '1.2.2'},
             'value': [1.0, 2.0]}
        """
        return {"interface": cls.mark_interface(), "value": array}

    @classmethod
    def interfaces(
        cls, with_disabled: bool = False, sort: bool = True
    ) -> tuple[type["Interface"], ...]:
        """
        Enabled interface subclasses

        Args:
            with_disabled (bool): If ``True`` , get every known interface.
                If ``False`` (default), get only enabled interfaces.
            sort (bool): If ``True`` (default), sort interfaces by priority.
                If ``False`` , sorted by definition order. Used for recursion:
                we only want to sort once at the top level.

        Returns:
            Tuple of interface classes, each direct subclass followed by its
            own subclasses (before any sorting).
        """
        # get recursively
        subclasses = []
        for i in cls.__subclasses__():
            # A single condition so that an enabled interface is not added
            # twice when ``with_disabled`` is also set.
            if with_disabled or i.enabled():
                subclasses.append(i)

            subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))

        if sort:
            # highest priority first
            subclasses = sorted(
                subclasses,
                key=attrgetter("priority"),
                reverse=True,
            )

        return tuple(subclasses)

    @classmethod
    def return_types(cls) -> tuple[NDArrayType, ...]:
        """Return types for all enabled interfaces"""
        return tuple(i.return_type for i in cls.interfaces())

    @classmethod
    def input_types(cls) -> tuple[Any, ...]:
        """Input types for all enabled interfaces"""
        in_types = []
        for iface in cls.interfaces():
            if isinstance(iface.input_types, (tuple, list)):
                in_types.extend(iface.input_types)
            else:  # pragma: no cover
                in_types.append(iface.input_types)

        return tuple(in_types)

    @classmethod
    def match_mark(cls, array: Any) -> type["Interface"] | None:
        """
        Match a marked JSON dump of this array to the interface that it indicates.

        First find an interface that matches by name, and then run its
        ``check`` method, because arrays can be dumped with a mark
        but without ``round_trip == True`` (and thus can't necessarily
        use the same interface that they were dumped with)

        Returns:
            Interface if match found, None otherwise
        """
        mark = MarkedJson.try_cast(array)
        if not isinstance(mark, MarkedJson):
            return None

        interface = mark.interface.match_by_name()
        if interface is not None and interface.check(mark.value):
            return interface
        return None

    @classmethod
    def match(cls, array: Any, fast: bool = False) -> type["Interface"]:
        """
        Find the interface that should be used for this array based on its input type

        First runs the ``check`` method for all interfaces returned by
        :meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
        and if no match is found then try the numpy interface. This is because
        :meth:`.NumpyInterface.check` can be expensive, as it may try to
        load the array into memory.

        Args:
            array: The input to find an interface for
            fast (bool): if ``False`` , check all interfaces and raise exceptions for
              having multiple matching interfaces (default). If ``True`` ,
              check each interface (as ordered by its ``priority`` , decreasing),
              and return on the first match.

        Raises:
            :class:`.TooManyMatchesError`: more than one non-numpy interface
                matches (only when ``fast`` is ``False``)
            :class:`.NoMatchError`: no interface matches, including when no
                interface named ``numpy`` is available to fall back on
        """
        # Shortcircuit match if this is a marked json dump
        array = MarkedJson.try_cast(array)
        if (match := cls.match_mark(array)) is not None:
            return match
        elif isinstance(array, MarkedJson):
            # marked, but the marked interface could not be used:
            # match on the inner value instead
            array = array.value

        # first try and find a non-numpy interface, since the numpy interface
        # will try and load the array into memory in its check method
        interfaces = cls.interfaces()
        non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
        # ``None`` rather than an IndexError when no numpy interface is known
        np_interface = next((i for i in interfaces if i.name == "numpy"), None)

        if fast:
            for i in non_np_interfaces:
                if i.check(array):
                    return i
            matches = []
        else:
            matches = [i for i in non_np_interfaces if i.check(array)]

        if len(matches) > 1:
            msg = f"More than one interface matches input {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)
        elif len(matches) == 0:
            # now try the numpy interface
            if np_interface is not None and np_interface.check(array):
                return np_interface
            else:
                raise NoMatchError(f"No matching interfaces found for input {array}")
        else:
            return matches[0]

    @classmethod
    def match_output(cls, array: Any) -> type["Interface"]:
        """
        Find the interface that should be used based on the output type -
        in the case that the output type differs from the input type, eg.
        the HDF5 interface, match an instantiated array for purposes of
        serialization to json, etc.

        Raises:
            :class:`.TooManyMatchesError`: ``array`` is an instance of more
                than one interface's ``return_type``
            :class:`.NoMatchError`: ``array`` is an instance of none of them
        """
        matches = [i for i in cls.interfaces() if isinstance(array, i.return_type)]
        if len(matches) > 1:
            msg = f"More than one interface matches output {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)
        elif len(matches) == 0:
            raise NoMatchError(f"No matching interfaces found for output {array}")
        else:
            return matches[0]

    @classmethod
    @lru_cache(maxsize=32)
    def mark_interface(cls) -> InterfaceMark:
        """
        Create an interface mark indicating this interface for validation after
        JSON serialization with ``round_trip==True``

        The mark records the name of the module that defines the class, the
        class name, the interface ``name``, and the installed version of the
        module's top-level package.
        """
        interface_module = inspect.getmodule(cls)
        interface_module = (
            None if interface_module is None else interface_module.__name__
        )
        # NOTE(review): ``InterfaceMark.module`` and ``.version`` are declared
        # as ``str``, so the ``None`` fallbacks below would presumably fail
        # model validation - confirm whether that is intended.
        try:
            v = (
                None
                if interface_module is None
                else version(interface_module.split(".")[0])
            )
        except (
            PackageNotFoundError
        ):  # pragma: no cover - no tests for missing interface deps
            v = None

        return InterfaceMark(
            module=interface_module, cls=cls.__name__, name=cls.name, version=v
        )
614

615

616
class Proxy(ABC):
    """
    Abstract base for proxies that present a non-array data source
    (a video, for example) as an array.
    """

    @classmethod
    @abstractmethod
    def proxy_for(cls) -> type[Interface]:
        """
        Name the interface this proxy stands in for, so that the proxy can be
        used with the NDArraySchema annotation with any of the input types
        that the Interface supports.
        """
        raise NotImplementedError
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc