• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

p2p-ld / numpydantic / 18424677238

11 Oct 2025 04:53AM UTC coverage: 97.836%. First build
18424677238

Pull #61

github

web-flow
Merge e9495eedd into 1dab66f12
Pull Request #61: Support leading zero length dimensions

40 of 43 new or added lines in 6 files covered. (93.02%)

1537 of 1571 relevant lines covered (97.84%)

9.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.73
/src/numpydantic/interface/interface.py
1
"""
2
Base Interface metaclass
3
"""
4

5
import inspect
10✔
6
import warnings
10✔
7
from abc import ABC, abstractmethod
10✔
8
from functools import lru_cache
10✔
9
from importlib.metadata import PackageNotFoundError, version
10✔
10
from operator import attrgetter
10✔
11
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
10✔
12

13
import numpy as np
10✔
14
from pydantic import BaseModel, SerializationInfo, ValidationError
10✔
15

16
from numpydantic.exceptions import (
10✔
17
    DtypeError,
18
    MarkMismatchError,
19
    NoMatchError,
20
    ShapeError,
21
    TooManyMatchesError,
22
)
23
from numpydantic.types import DtypeType, NDArrayType, ShapeType
10✔
24
from numpydantic.validation import validate_dtype, validate_shape
10✔
25

26
# Array type produced by an interface, bound to ``NDArrayType``
T = TypeVar("T", bound=NDArrayType)
# A ``JsonDict`` subclass (forward reference, the class is defined below)
U = TypeVar("U", bound="JsonDict")
V = TypeVar("V")  # input type
W = TypeVar("W")  # Any type in handle_input
30

31

32
class InterfaceMark(BaseModel):
    """
    JSON-serializable description of the interface that dumped an array,
    stored alongside the data so a JSON dump can be round-tripped.
    """

    module: str
    cls: str
    name: str
    version: str

    def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool:
        """
        Compare this mark against the mark generated by an interface class.

        Args:
            cls (Type): Interface type whose ``mark_interface()`` result is
                compared with this mark
            raise_on_error (bool): If ``True``, raise instead of returning
                ``False`` when the marks differ

        Returns:
            bool: ``True`` when both marks are equal, ``False`` otherwise

        Raises:
            :class:`.MarkMismatchError`: when the marks differ and
                ``raise_on_error`` is ``True``
        """
        current = cls.mark_interface()
        if self == current:
            return True

        if raise_on_error:
            raise MarkMismatchError(
                "Mismatch between serialized mark and current interface, "
                f"Serialized: {self}; current: {cls}"
            )
        return False

    def match_by_name(self) -> Optional[Type["Interface"]]:
        """
        Look through the interfaces returned by ``Interface.interfaces``
        for the first one whose ``name`` equals this mark's ``name``.

        Returns:
            The matching interface class, or ``None`` if there is none.
        """
        candidates = Interface.interfaces(sort=False)
        return next(
            (iface for iface in candidates if iface.name == self.name),
            None,
        )
74

75

76
class JsonDict(BaseModel):
    """
    Representation of array when dumped with round_trip == True.

    .. admonition:: Developer's Note

        Any JsonDict that contains an actual array should be named ``value``
        rather than array (or any other name), and nothing but the
        array data should be named ``value`` .

        During JSON serialization, it becomes ambiguous what contains an array
        of data vs. an array of metadata. For the moment we would like to
        reserve the ability to have lists of metadata, so until we rule that out,
        we would like to be able to avoid iterating over every element of an array
        in any context parameter transformation like relativizing/absolutizing paths.
        To avoid that, it's good to agree on a single value name -- ``value`` --
        and avoid using it for anything else.

    """

    type: str

    @abstractmethod
    def to_array_input(self) -> V:
        """
        Convert this roundtrip specifier to the relevant input class
        (one of the ``input_types`` of an interface).
        """

    @classmethod
    def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
        """
        Check whether a given dictionary matches this JsonDict specification

        Args:
            val (dict): The dictionary to check for validity
            raise_on_error (bool): If ``True``, raise the validation error
                rather than returning a bool. (default: ``False``)

        Returns:
            bool - true if valid, false if not

        Raises:
            :class:`pydantic.ValidationError`: if ``val`` is invalid and
                ``raise_on_error`` is ``True``
        """
        try:
            cls.model_validate(val)
        except ValidationError:
            if raise_on_error:
                # bare raise keeps the original traceback intact
                raise
            return False
        return True

    @classmethod
    def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
        """
        Handle input that is the json serialized roundtrip version
        (from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``)
        converting it to the input format with :meth:`.JsonDict.to_array_input`
        or passing it through if not applicable

        Args:
            value: A dict of fields for this model, an instance of this model,
                or any other object

        Returns:
            The result of :meth:`.JsonDict.to_array_input` for dicts and model
            instances; any other value is returned unchanged.
        """
        if isinstance(value, dict):
            value = cls(**value).to_array_input()
        elif isinstance(value, cls):
            value = value.to_array_input()
        return value

    @staticmethod
    def reshape_input(value: T, shape: tuple[int, ...]) -> T:
        """
        If a `reshape` value is present on the array, and the array shape doesn't match,
        attempt to reshape it.

        Args:
            value: Array-like object exposing ``shape`` and ``reshape``
            shape: The shape recorded in the roundtrip form

        Returns:
            The reshaped array, or the input unchanged when the shapes
            already agree or the reshape fails (a warning is emitted in the
            latter case).
        """
        if value.shape != shape:
            try:
                value = value.reshape(shape)
            except ValueError:
                # Best effort: leave the data as-is and let shape validation
                # downstream decide whether the input is acceptable.
                warnings.warn(
                    f"Input data has shape {value.shape}, "
                    f"but roundtrip form specifies {shape}, "
                    f"and {value.shape} can't be cast to {shape}. "
                    f"Attempting to proceed with validation without reshaping.",
                    stacklevel=1,
                )
        return value
158

159

160
class MarkedJson(BaseModel):
    """
    Model of a JSON dump that wraps the array value together with an
    interface mark, as produced by
    ``model_dump_json({'mark_interface': True})``
    """

    interface: InterfaceMark
    value: Union[list, dict]
    """
    Inner value of the array, we don't validate for JsonDict here,
    that should be downstream from us for performance reasons
    """

    @classmethod
    def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]:
        """
        Cast ``value`` to a :class:`.MarkedJson` when it is a dict carrying
        both an ``interface`` and a ``value`` key and validates as one;
        in every other case hand the input back untouched.
        """
        looks_marked = (
            isinstance(value, dict) and "interface" in value and "value" in value
        )
        if not looks_marked:
            return value

        try:
            return MarkedJson(**value)
        except ValidationError:
            # fine, just not a MarkedJson dict even if it looks like one
            return value
185

186

187
class Interface(ABC, Generic[T]):
    """
    Abstract parent class for interfaces to different array formats
    """

    input_types: Tuple[Any, ...]
    return_type: Type[T]
    priority: int = 0

    def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
        """
        Store the shape and dtype specification to validate against.

        Args:
            shape: Shape specification, ``Any`` (default) accepts every shape
            dtype: Dtype specification, passed to ``validate_dtype``
        """
        self.shape = shape
        self.dtype = dtype

    def validate(self, array: Any) -> T:
        """
        Validate input, returning final array type

        Calls the methods, in order:

        * array = :meth:`.deserialize` (array)
        * array = :meth:`.before_validation` (array)
        * dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
            override if eg. the dtype is not contained in ``array.dtype``
        * valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
            the one in the NDArray specification. Override if special
            validation logic is needed for a given format
        * :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
            raise an exception if it was invalid. Override to implement custom
            exceptions or error conditions, or make validation errors conditional.
        * array = :meth:`.after_validate_dtype` (array) - hook for additional
            validation or array modification mid-validation
        * shape = :meth:`.get_shape` (array) - get the shape from the array,
            override if eg. the shape is not contained in ``array.shape``
        * valid = :meth:`.validate_shape` (shape) - check that the shape matches
            the one in the NDArray specification. Override if special validation
            logic is needed.
        * :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
            raise an exception if it was invalid. You know the deal bc it's the same
            as raise for dtype.
        * :meth:`.after_validation` - hook after validation for modifying the array
            that is set as the model field value

        Follow the method signatures and return types to override.

        Implementing an interface subclass largely consists of overriding these methods
        as needed.

        Raises:
            If validation fails, rather than eg. returning ``False``, exceptions will
            be raised (to halt the rest of the pydantic validation process).
            When using interfaces outside of pydantic, you must catch both
            :class:`.DtypeError` and :class:`.ShapeError` (both of which are children
            of :class:`.InterfaceError` )
        """
        array = self.deserialize(array)

        array = self.before_validation(array)

        dtype = self.get_dtype(array)
        dtype_valid = self.validate_dtype(dtype)
        self.raise_for_dtype(dtype_valid, dtype)
        array = self.after_validate_dtype(array)

        shape = self.get_shape(array)
        shape_valid = self.validate_shape(shape)
        self.raise_for_shape(shape_valid, shape)

        array = self.after_validation(array)

        return array

    def deserialize(self, array: Any) -> Union[V, Any]:
        """
        If given a JSON serialized version of the array,
        deserialize it first.

        If a roundtrip-serialized :class:`.JsonDict`,
        pass to :meth:`.JsonDict.handle_input`.

        If a roundtrip-serialized :class:`.MarkedJson`,
        unpack mark, check for validity, warn if not,
        and try to continue with validation
        """
        if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
            try:
                marked_array.interface.is_valid(self.__class__, raise_on_error=True)
            except MarkMismatchError as e:
                # A mismatched mark is not fatal: warn and validate the
                # wrapped value with this interface anyway.
                warnings.warn(
                    str(e) + "\nAttempting to continue validation...", stacklevel=2
                )
            array = marked_array.value

        return self.json_model.handle_input(array)

    def before_validation(self, array: Any) -> NDArrayType:
        """
        Optional step pre-validation that coerces the input into a type that can be
        validated for shape and dtype

        Default method is a no-op
        """
        return array

    def get_dtype(self, array: NDArrayType) -> DtypeType:
        """
        Get the dtype from the input array.

        Object arrays (``array.dtype.type is np.object_``) are delegated to
        :meth:`.get_object_dtype`; otherwise ``array.dtype`` is returned.
        """
        if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
            return self.get_object_dtype(array)
        else:
            return array.dtype

    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
        """
        When an array contains an object, get the dtype of the object contained
        by the array.

        If this method returns `Any`, the dtype validation passes -
        used for e.g. empty arrays for which the dtype of the array can't be determined
        (since there are no objects).
        """
        try:
            # only the first element is inspected
            return type(array.ravel()[0])
        except IndexError:
            return Any

    def validate_dtype(self, dtype: DtypeType) -> bool:
        """
        Validate the dtype of the given array, returning
        ``True`` if valid, ``False`` if not.
        """
        return validate_dtype(dtype, self.dtype)

    def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
        """
        After validating, raise an exception if invalid

        Raises:
            :class:`~numpydantic.exceptions.DtypeError`
        """
        if not valid:
            raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")

    def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
        """
        Hook to modify array after validating dtype.
        Default is a no-op.
        """
        return array

    def get_shape(self, array: NDArrayType) -> Tuple[int, ...]:
        """
        Get the shape from the array as a tuple of integers
        """
        return array.shape

    def validate_shape(self, shape: Tuple[int, ...]) -> bool:
        """
        Validate the shape of the given array against the shape
        specifier, returning ``True`` if valid, ``False`` if not.

        A shape specifier of ``Any`` accepts every shape.
        """
        if self.shape is Any:
            return True

        return validate_shape(shape, self.shape)

    def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
        """
        Raise a ShapeError if the shape is invalid.

        Raises:
            :class:`~numpydantic.exceptions.ShapeError`
        """
        if not valid:
            raise ShapeError(
                f"Invalid shape! expected shape {self.shape.prepared_args}, "
                f"got shape {shape}"
            )

    def after_validation(self, array: NDArrayType) -> T:
        """
        Optional step post-validation that coerces the intermediate array type into the
        return type

        Default method is a no-op
        """
        return array

    @classmethod
    @abstractmethod
    def check(cls, array: Any) -> bool:
        """
        Method to check whether a given input applies to this interface
        """

    @classmethod
    @abstractmethod
    def enabled(cls) -> bool:
        """
        Check whether this array interface can be used (eg. its dependent packages are
        installed, etc.)
        """

    @property
    @abstractmethod
    def name(self) -> str:
        """
        Short name for this interface
        """

    @property
    @abstractmethod
    def json_model(self) -> JsonDict:
        """
        The :class:`.JsonDict` model used for roundtripping
        JSON serialization
        """

    @classmethod
    @abstractmethod
    def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
        """
        Convert an array of :attr:`.return_type` to a JSON-compatible format using
        base python types
        """

    @classmethod
    def mark_json(cls, array: Union[list, dict]) -> dict:
        """
        When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
        add additional annotations that would allow the serialized array to be
        roundtripped.

        Default is just to add an :class:`.InterfaceMark`

        Returns:
            dict with the keys ``interface`` (the result of
            :meth:`.mark_interface`) and ``value`` (``array``, unchanged)

        Examples:

            The example is illustrative: the mark also carries the ``name``
            field declared on :class:`.InterfaceMark`, and ``version`` is
            looked up from the installed package.

            >>> from pprint import pprint
            >>> pprint(Interface.mark_json([1.0, 2.0]))
            {'interface': {'cls': 'Interface',
                           'module': 'numpydantic.interface.interface',
                           'version': '1.2.2'},
             'value': [1.0, 2.0]}
        """
        return {"interface": cls.mark_interface(), "value": array}

    @classmethod
    def interfaces(
        cls, with_disabled: bool = False, sort: bool = True
    ) -> Tuple[Type["Interface"], ...]:
        """
        Enabled interface subclasses

        Args:
            with_disabled (bool): If ``True`` , get every known interface.
                If ``False`` (default), get only enabled interfaces.
            sort (bool): If ``True`` (default), sort interfaces by priority.
                If ``False`` , sorted by definition order. Used for recursion:
                we only want to sort once at the top level.

        Returns:
            Tuple of interface classes, each subclass listed once.
        """
        # get recursively
        subclasses = []
        for i in cls.__subclasses__():
            # A single combined condition, so that an enabled interface is
            # not appended twice when ``with_disabled`` is set.
            if with_disabled or i.enabled():
                subclasses.append(i)

            subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))

        if sort:
            subclasses = sorted(
                subclasses,
                key=attrgetter("priority"),
                reverse=True,
            )

        return tuple(subclasses)

    @classmethod
    def return_types(cls) -> Tuple[NDArrayType, ...]:
        """Return types for all enabled interfaces"""
        return tuple(i.return_type for i in cls.interfaces())

    @classmethod
    def input_types(cls) -> Tuple[Any, ...]:
        """Input types for all enabled interfaces"""
        in_types = []
        for iface in cls.interfaces():
            if isinstance(iface.input_types, (tuple, list)):
                in_types.extend(iface.input_types)
            else:  # pragma: no cover
                in_types.append(iface.input_types)

        return tuple(in_types)

    @classmethod
    def match_mark(cls, array: Any) -> Optional[Type["Interface"]]:
        """
        Match a marked JSON dump of this array to the interface that it indicates.

        First find an interface that matches by name, and then run its
        ``check`` method, because arrays can be dumped with a mark
        but without ``round_trip == True`` (and thus can't necessarily
        use the same interface that they were dumped with)

        Returns:
            Interface if match found, None otherwise
        """
        mark = MarkedJson.try_cast(array)
        if not isinstance(mark, MarkedJson):
            return None

        interface = mark.interface.match_by_name()
        if interface is not None and interface.check(mark.value):
            return interface
        return None

    @classmethod
    def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
        """
        Find the interface that should be used for this array based on its input type

        First runs the ``check`` method for all interfaces returned by
        :meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
        and if no match is found then try the numpy interface. This is because
        :meth:`.NumpyInterface.check` can be expensive, as it will try to load
        the array into memory.

        Args:
            array: The input to find an interface for. A marked JSON dump is
                matched through its mark first, see :meth:`.match_mark`
            fast (bool): if ``False`` , check all interfaces and raise exceptions for
              having multiple matching interfaces (default). If ``True`` ,
              check each interface (as ordered by its ``priority`` , decreasing),
              and return on the first match.

        Raises:
            :class:`.TooManyMatchesError`: more than one non-numpy interface
                matches (only when ``fast`` is ``False``)
            :class:`.NoMatchError`: no interface matches, including when no
                numpy interface is available to fall back on
        """
        # Shortcircuit match if this is a marked json dump
        array = MarkedJson.try_cast(array)
        if (match := cls.match_mark(array)) is not None:
            return match
        elif isinstance(array, MarkedJson):
            # the mark did not resolve to an interface, match on the raw value
            array = array.value

        # first try and find a non-numpy interface, since the numpy interface
        # will try and load the array into memory in its check method
        interfaces = cls.interfaces()
        non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
        # ``None`` when no numpy interface is enabled, rather than an IndexError
        np_interface = next((i for i in interfaces if i.name == "numpy"), None)

        if fast:
            for i in non_np_interfaces:
                if i.check(array):
                    return i
            matches = []
        else:
            matches = [i for i in non_np_interfaces if i.check(array)]

        if len(matches) > 1:
            msg = f"More than one interface matches input {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)

        if len(matches) == 1:
            return matches[0]

        # now try the numpy interface
        if np_interface is not None and np_interface.check(array):
            return np_interface

        raise NoMatchError(f"No matching interfaces found for input {array}")

    @classmethod
    def match_output(cls, array: Any) -> Type["Interface"]:
        """
        Find the interface that should be used based on the output type -
        in the case that the output type differs from the input type, eg.
        the HDF5 interface, match an instantiated array for purposes of
        serialization to json, etc.

        Raises:
            :class:`.TooManyMatchesError`: more than one interface's
                ``return_type`` matches
            :class:`.NoMatchError`: no interface's ``return_type`` matches
        """
        matches = [i for i in cls.interfaces() if isinstance(array, i.return_type)]
        if len(matches) > 1:
            msg = f"More than one interface matches output {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)
        elif len(matches) == 0:
            raise NoMatchError(f"No matching interfaces found for output {array}")
        else:
            return matches[0]

    @classmethod
    @lru_cache(maxsize=32)
    def mark_interface(cls) -> InterfaceMark:
        """
        Create an interface mark indicating this interface for validation after
        JSON serialization with ``round_trip==True``

        The version is looked up for the top-level package of the module
        that defines the interface class.
        """
        interface_module = inspect.getmodule(cls)
        interface_module = (
            None if interface_module is None else interface_module.__name__
        )
        try:
            v = (
                None
                if interface_module is None
                else version(interface_module.split(".")[0])
            )
        except (
            PackageNotFoundError
        ):  # pragma: no cover - no tests for missing interface deps
            v = None

        # NOTE(review): InterfaceMark declares ``module`` and ``version`` as
        # ``str``; presumably a ``None`` here fails model validation - confirm.
        return InterfaceMark(
            module=interface_module, cls=cls.__name__, name=cls.name, version=v
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc