• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

p2p-ld / numpydantic / 10382572278

14 Aug 2024 06:29AM UTC coverage: 98.177% (+0.004%) from 98.173%
10382572278

push

github

web-flow
Merge pull request #8 from p2p-ld/constructable

Make NDArray callable as a functional validator

3 of 3 new or added lines in 1 file covered. (100.0%)

808 of 823 relevant lines covered (98.18%)

9.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/src/numpydantic/interface/interface.py
1
"""
2
Abstract base class for array interfaces
3
"""
4

5
from abc import ABC, abstractmethod
10✔
6
from operator import attrgetter
10✔
7
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
10✔
8

9
import numpy as np
10✔
10
from pydantic import SerializationInfo
10✔
11

12
from numpydantic.exceptions import (
10✔
13
    DtypeError,
14
    NoMatchError,
15
    ShapeError,
16
    TooManyMatchesError,
17
)
18
from numpydantic.shape import check_shape
10✔
19
from numpydantic.types import DtypeType, NDArrayType, ShapeType
10✔
20

21
# Type variable for the array type produced by an interface (see
# ``Interface.return_type`` and ``Interface.validate``), bounded by NDArrayType.
T = TypeVar("T", bound=NDArrayType)
10✔
22

23

24
class Interface(ABC, Generic[T]):
    """
    Abstract parent class for interfaces to different array formats

    Subclasses declare the ``input_types`` they accept, the ``return_type``
    they produce and a ``priority`` (higher values are sorted first by
    :meth:`.interfaces`), and implement :meth:`.check` and :meth:`.enabled`.
    """

    # NOTE(review): on this base class the name ``input_types`` is rebound by
    # the classmethod of the same name defined below; concrete subclasses are
    # presumably expected to assign a tuple/list of accepted input types.
    input_types: Tuple[Any, ...]
    return_type: Type[T]
    priority: int = 0

    def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
        """
        Store the shape and dtype specification that arrays are validated against
        """
        self.shape = shape
        self.dtype = dtype

    def validate(self, array: Any) -> T:
        """
        Validate input, returning final array type

        Calls the methods, in order:

        * array = :meth:`.before_validation` (array)
        * dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
            override if eg. the dtype is not contained in ``array.dtype``
        * valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
            the one in the NDArray specification. Override if special
            validation logic is needed for a given format
        * :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
            raise an exception if it was invalid. Override to implement custom
            exceptions or error conditions, or make validation errors conditional.
        * array = :meth:`.after_validate_dtype` (array) - hook for additional
            validation or array modification mid-validation
        * shape = :meth:`.get_shape` (array) - get the shape from the array,
            override if eg. the shape is not contained in ``array.shape``
        * valid = :meth:`.validate_shape` (shape) - check that the shape matches
            the one in the NDArray specification. Override if special validation
            logic is needed.
        * :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
            raise an exception if it was invalid. Works the same way as
            :meth:`.raise_for_dtype` .
        * array = :meth:`.after_validation` (array) - hook after validation for
            modifying the array that is set as the model field value

        Follow the method signatures and return types to override.

        Implementing an interface subclass largely consists of overriding these methods
        as needed.

        Raises:
            If validation fails, rather than eg. returning ``False``, exceptions will
            be raised (to halt the rest of the pydantic validation process).
            When using interfaces outside of pydantic, you must catch both
            :class:`.DtypeError` and :class:`.ShapeError` (both of which are children
            of :class:`.InterfaceError` )
        """
        array = self.before_validation(array)

        dtype = self.get_dtype(array)
        dtype_valid = self.validate_dtype(dtype)
        self.raise_for_dtype(dtype_valid, dtype)
        array = self.after_validate_dtype(array)

        shape = self.get_shape(array)
        shape_valid = self.validate_shape(shape)
        self.raise_for_shape(shape_valid, shape)

        array = self.after_validation(array)
        return array

    def before_validation(self, array: Any) -> NDArrayType:
        """
        Optional step pre-validation that coerces the input into a type that can be
        validated for shape and dtype

        Default method is a no-op
        """
        return array

    def get_dtype(self, array: NDArrayType) -> DtypeType:
        """
        Get the dtype from the input array

        When ``array.dtype`` is numpy's object dtype, the dtype is taken from
        the contained objects via :meth:`.get_object_dtype` instead.
        """
        if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
            return self.get_object_dtype(array)
        return array.dtype

    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
        """
        When an array contains an object, get the dtype of the object contained
        by the array.

        Only the first element of the flattened array is inspected.
        """
        return type(array.ravel()[0])

    def validate_dtype(self, dtype: DtypeType) -> bool:
        """
        Validate the dtype of the given array, returning
        ``True`` if valid, ``False`` if not.

        * ``Any`` accepts every dtype
        * a tuple of dtypes accepts any dtype contained in the tuple
        * ``np.str_`` accepts ``np.str_`` itself or a dtype whose ``type``
          attribute is ``np.str_``
        * otherwise the dtype must be a subclass of, or equal to, the specified dtype
        """
        if self.dtype is Any:
            return True

        if isinstance(self.dtype, tuple):
            valid = dtype in self.dtype
        elif self.dtype is np.str_:
            valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
        else:
            # try to match as any subclass, if self.dtype is a class
            try:
                valid = issubclass(dtype, self.dtype)
            except TypeError:
                # expected, if dtype or self.dtype is not a class
                valid = dtype == self.dtype

        return valid

    def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
        """
        After validating, raise an exception if invalid

        Raises:
            :class:`~numpydantic.exceptions.DtypeError`
        """
        if not valid:
            raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")

    def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
        """
        Hook to modify array after validating dtype.
        Default is a no-op.
        """
        return array

    def get_shape(self, array: NDArrayType) -> Tuple[int, ...]:
        """
        Get the shape from the array as a tuple of integers
        """
        return array.shape

    def validate_shape(self, shape: Tuple[int, ...]) -> bool:
        """
        Validate the shape of the given array against the shape
        specifier, returning ``True`` if valid, ``False`` if not.

        A shape specification of ``Any`` accepts every shape.
        """
        if self.shape is Any:
            return True

        return check_shape(shape, self.shape)

    def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
        """
        Raise a ShapeError if the shape is invalid.

        Raises:
            :class:`~numpydantic.exceptions.ShapeError`
        """
        if not valid:
            # Fall back to the raw specification when it has no ``prepared_args``
            # (eg. ``Any``), so an overridden ``validate_shape`` that rejects
            # such a spec still produces a ShapeError rather than AttributeError.
            expected = getattr(self.shape, "prepared_args", self.shape)
            raise ShapeError(
                f"Invalid shape! expected shape {expected}, " f"got shape {shape}"
            )

    def after_validation(self, array: NDArrayType) -> T:
        """
        Optional step post-validation that coerces the intermediate array type into the
        return type

        Default method is a no-op
        """
        return array

    @classmethod
    @abstractmethod
    def check(cls, array: Any) -> bool:
        """
        Method to check whether a given input applies to this interface
        """

    @classmethod
    @abstractmethod
    def enabled(cls) -> bool:
        """
        Check whether this array interface can be used (eg. its dependent packages are
        installed, etc.)
        """

    @classmethod
    def to_json(
        cls, array: T, info: Optional[SerializationInfo] = None
    ) -> Union[list, dict]:
        """
        Convert an array of :attr:`.return_type` to a JSON-compatible format using
        base python types

        The default implementation coerces non-numpy input with ``np.array`` and
        returns ``array.tolist()``; ``info`` is unused here.
        """
        if not isinstance(array, np.ndarray):  # pragma: no cover
            array = np.array(array)
        return array.tolist()

    @classmethod
    def interfaces(
        cls, with_disabled: bool = False, sort: bool = True
    ) -> Tuple[Type["Interface"], ...]:
        """
        Enabled interface subclasses

        Args:
            with_disabled (bool): If ``True`` , get every known interface.
                If ``False`` (default), get only enabled interfaces.
            sort (bool): If ``True`` (default), sort interfaces by priority.
                If ``False`` , sorted by definition order. Used for recursion:
                we only want to sort once at the top level.
        """
        # get recursively
        subclasses = []
        for i in cls.__subclasses__():
            # a single combined condition so that an enabled interface is not
            # listed twice when ``with_disabled`` is requested
            if with_disabled or i.enabled():
                subclasses.append(i)

            subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))

        if sort:
            subclasses = sorted(
                subclasses,
                key=attrgetter("priority"),
                reverse=True,
            )

        return tuple(subclasses)

    @classmethod
    def return_types(cls) -> Tuple[NDArrayType, ...]:
        """Return types for all enabled interfaces"""
        return tuple(i.return_type for i in cls.interfaces())

    @classmethod
    def input_types(cls) -> Tuple[Any, ...]:
        """Input types for all enabled interfaces"""
        in_types = []
        for iface in cls.interfaces():
            if isinstance(iface.input_types, (tuple, list)):
                in_types.extend(iface.input_types)
            else:  # pragma: no cover
                in_types.append(iface.input_types)

        return tuple(in_types)

    @classmethod
    def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
        """
        Find the interface that should be used for this array based on its input type

        First runs the ``check`` method for all interfaces returned by
        :meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
        and if no match is found then try the numpy interface. This is because
        :meth:`.NumpyInterface.check` can be expensive, as it may try to load
        the array into memory.

        Args:
            array: The input to find an interface for
            fast (bool): if ``False`` , check all interfaces and raise exceptions for
              having multiple matching interfaces (default). If ``True`` ,
              check each interface (as ordered by its ``priority`` , decreasing),
              and return on the first match.

        Raises:
            :class:`~numpydantic.exceptions.TooManyMatchesError` if more than one
            non-numpy interface matches (only when ``fast`` is ``False`` ),
            :class:`~numpydantic.exceptions.NoMatchError` if nothing matches.
        """
        # first try and find a non-numpy interface, since the numpy interface
        # will try and load the array into memory in its check method
        interfaces = cls.interfaces()
        non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
        # ``None`` when the numpy interface is not available, rather than failing
        # with an IndexError before any other interface has been tried
        np_interface = next(
            (i for i in interfaces if i.__name__ == "NumpyInterface"), None
        )

        if fast:
            matches = []
            for i in non_np_interfaces:
                if i.check(array):
                    return i
        else:
            matches = [i for i in non_np_interfaces if i.check(array)]

        if len(matches) > 1:
            msg = f"More than one interface matches input {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)

        if len(matches) == 1:
            return matches[0]

        # now try the numpy interface
        if np_interface is not None and np_interface.check(array):
            return np_interface

        raise NoMatchError(f"No matching interfaces found for input {array}")

    @classmethod
    def match_output(cls, array: Any) -> Type["Interface"]:
        """
        Find the interface that should be used based on the output type -
        in the case that the output type differs from the input type, eg.
        the HDF5 interface, match an instantiated array for purposes of
        serialization to json, etc.

        Raises:
            :class:`~numpydantic.exceptions.TooManyMatchesError` if ``array`` is an
            instance of more than one interface's ``return_type`` ,
            :class:`~numpydantic.exceptions.NoMatchError` if it matches none.
        """
        matches = [i for i in cls.interfaces() if isinstance(array, i.return_type)]
        if len(matches) > 1:
            msg = f"More than one interface matches output {array}:\n"
            msg += "\n".join([f"  - {i}" for i in matches])
            raise TooManyMatchesError(msg)
        elif len(matches) == 0:
            raise NoMatchError(f"No matching interfaces found for output {array}")
        else:
            return matches[0]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc