• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

p2p-ld / numpydantic / 17004411285

16 Aug 2025 04:41AM UTC coverage: 97.835%. First build
17004411285

Pull #58

github

web-flow
Merge 3377a2d8c into 1dab66f12
Pull Request #58: Convert `NDArray` to a Protocol

51 of 55 new or added lines in 2 files covered. (92.73%)

1491 of 1524 relevant lines covered (97.83%)

17.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.04
/src/numpydantic/ndarray.py
1
"""
2
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
3

4
.. note::
5

6
    This module should *only* have the :class:`.NDArray` class in it, because the
7
    type stub ``ndarray.pyi`` is only created for :class:`.NDArray` . Otherwise,
8
    type checkers will complain about using any helper functions elsewhere -
9
    those all belong in :mod:`numpydantic.schema` .
10

11
    Keeping with nptyping's style, NDArrayMeta is in this module even if it's
12
    excluded from the type stub.
13

14
"""
15

16
import sys
18✔
17
from typing import (
18✔
18
    TYPE_CHECKING,
19
    Any,
20
    Literal,
21
    Protocol,
22
    Tuple,
23
    TypeVar,
24
    Union,
25
    _ProtocolMeta,
26
    get_origin,
27
    runtime_checkable,
28
)
29

30
import numpy as np
18✔
31
from pydantic import GetJsonSchemaHandler
18✔
32
from pydantic_core import core_schema
18✔
33

34
from numpydantic.dtype import DType
18✔
35
from numpydantic.exceptions import InterfaceError
18✔
36
from numpydantic.interface import Interface
18✔
37
from numpydantic.maps import python_to_nptyping
18✔
38
from numpydantic.schema import (
18✔
39
    get_validate_interface,
40
    make_json_schema,
41
)
42
from numpydantic.serialization import jsonize_array
18✔
43
from numpydantic.types import DtypeType, NDArrayType, ShapeType
18✔
44
from numpydantic.validation.dtype import is_union
18✔
45
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
18✔
46
from numpydantic.vendor.nptyping.structure import Structure
18✔
47
from numpydantic.vendor.nptyping.structure_expression import check_type_names
18✔
48
from numpydantic.vendor.nptyping.typing_ import (
18✔
49
    dtype_per_name,
50
)
51

52
if sys.version_info < (3, 11):
18✔
53
    from typing_extensions import Self
10✔
54
else:
55
    from typing import Self
8✔
56

57
if TYPE_CHECKING:  # pragma: no cover
58
    from pydantic._internal._schema_generation_shared import (
59
        CallbackGetCoreSchemaHandler,
60
    )
61

62
    from numpydantic import Shape
63

64

65
def _is_literal_like(item: Any) -> bool:
18✔
66
    """
67
    Changes from nptyping:
68
    - doesn't just ducktype for literal but actually, yno, checks for being literal
69
    """
70
    return get_origin(item) is Literal
18✔
71

72

73
def _get_shape(dtype_candidate: Any) -> "Shape":
    """
    Resolve the shape argument of an ``NDArray[shape, dtype]`` annotation.

    Override of nptyping's base method so that numpydantic's own
    :class:`~numpydantic.validation.shape.Shape` is used.

    Args:
        dtype_candidate: The shape argument passed to ``NDArray[...]``.
            The parameter name is historical; it holds a *shape* candidate.

    Returns:
        ``typing.Any`` when the shape is unconstrained (``Any`` or a bare
        ``Shape``), the ``Shape`` subclass itself when one is given, or
        ``Shape[expr]`` built from a ``Literal[expr]``.

    Raises:
        InvalidArgumentsError: if the argument matches none of the above.
    """
    from numpydantic.validation.shape import Shape

    # A bare ``Any`` and an unsubscripted ``Shape`` both mean "any shape".
    if dtype_candidate is Any or dtype_candidate is Shape:
        return Any

    if issubclass(dtype_candidate, Shape):
        return dtype_candidate

    if _is_literal_like(dtype_candidate):
        # Literal["<expr>"] is sugar for Shape["<expr>"].
        return Shape[dtype_candidate.__args__[0]]

    raise InvalidArgumentsError(
        f"Unexpected argument '{dtype_candidate}', expecting"
        " Shape[<ShapeExpression>]"
        " or Literal[<ShapeExpression>]"
        " or typing.Any."
    )
94

95

96
def _get_dtype(dtype_candidate: Any) -> DType:
    """
    Resolve the dtype argument of an ``NDArray[shape, dtype]`` annotation.

    Override of nptyping's base ``_get_dtype`` that additionally allows
    compound (tuple) dtypes, whose members are resolved recursively.

    Args:
        dtype_candidate: The dtype argument passed to ``NDArray[...]``.

    Returns:
        ``typing.Any``; the candidate itself if it is a numpy scalar type,
        is accepted by ``is_union``, or is a :class:`.Structure`; a
        ``Structure`` built from a ``Literal[...]``; a tuple of resolved
        members; or, for anything unrecognized, the candidate unchanged.
    """
    # Candidates listed in ``python_to_nptyping`` are swapped for the mapped value.
    if dtype_candidate in python_to_nptyping:
        dtype_candidate = python_to_nptyping[dtype_candidate]

    if dtype_candidate is Any:
        return Any

    is_numpy_scalar = isinstance(dtype_candidate, type) and issubclass(
        dtype_candidate, np.generic
    )
    if is_numpy_scalar or is_union(dtype_candidate):
        return dtype_candidate

    if issubclass(dtype_candidate, Structure):  # pragma: no cover
        check_type_names(dtype_candidate, dtype_per_name)
        return dtype_candidate

    if _is_literal_like(dtype_candidate):  # pragma: no cover
        structure = Structure[dtype_candidate.__args__[0]]
        check_type_names(structure, dtype_per_name)
        return structure

    if isinstance(dtype_candidate, tuple):  # pragma: no cover
        return tuple(_get_dtype(member) for member in dtype_candidate)

    # arbitrary dtype - allow failure elsewhere :)
    return dtype_candidate
124

125

126
# Type parameters of the ``NDArray`` protocol: shape and dtype, respectively.
TShape, TDType = TypeVar("TShape"), TypeVar("TDType")
128

129

130
class NDArrayMeta(_ProtocolMeta):
    """
    Metaclass to provide class-level methods to NDArray protocol
    without suggesting they are part of the protocol definition.

    Subscripting ``NDArray[shape, dtype]`` (see :meth:`__getitem__`) creates a
    new subclass whose ``__args__`` holds the resolved ``(shape, dtype)`` pair.
    The other methods read that pair to validate values, implement
    ``isinstance`` checks, and build pydantic core / JSON schemas.
    """

    # Resolved (shape, dtype) constraints; unconstrained until subscripted.
    __args__: Tuple[ShapeType, DtypeType] = (Any, Any)

    def __call__(cls, val: NDArrayType) -> NDArrayType:
        """
        Call ndarray as a validator function.

        Runs the validator returned by :func:`.get_validate_interface` for
        this class's ``(shape, dtype)`` on ``val`` and returns its result.
        """
        shape, dtype = cls.__args__
        return get_validate_interface(shape, dtype)(val)

    def __instancecheck__(self, instance: Any) -> bool:
        """
        Extended type checking that determines whether

        1) the ``type`` of the given instance is one of those in
            :meth:`.Interface.input_types`

        but also

        2) it satisfies the constraints set on the :class:`.NDArray` annotation

        Args:
            instance (:class:`typing.Any`): Thing to check!

        Returns:
            bool: ``True`` if matches constraints, ``False`` otherwise.
        """
        shape, dtype = self.__args__
        try:
            interface_cls = Interface.match(instance, fast=True)
            interface_cls(shape, dtype).validate(instance)
        except InterfaceError:
            # presumably raised when no interface matches the instance or its
            # validation fails - confirm against Interface.match / validate
            return False
        return True

    def _dtype_to_str(cls, dtype: Any) -> str:
        """
        Render a dtype constraint as a human-readable string.

        ``Any`` becomes ``"Any"``, a tuple of dtypes becomes a comma-separated
        list, and everything else (including :class:`.Structure` types) is
        rendered with ``str(dtype)``.
        """
        if dtype is Any:
            return "Any"
        if isinstance(dtype, tuple):
            return ", ".join(str(dt) for dt in dtype)
        # Structure types and all other dtypes share the same ``str`` rendering,
        # so no ``issubclass(dtype, Structure)`` check is needed. That check
        # could raise TypeError for non-class dtypes (e.g. unions), depending on
        # Structure's metaclass.
        return str(dtype)

    def __getitem__(cls, args: Union[type[Any], tuple[type[Any], type[Any]]]):
        """
        Parameterize the array type: ``NDArray[shape]`` or ``NDArray[shape, dtype]``.

        Args:
            args: A single shape argument, a one-element tuple holding the
                shape (``NDArray[shape,]``), or a ``(shape, dtype)`` tuple.
                Elements after the second are ignored. An unbound
                :class:`typing.TypeVar` in either position is treated as
                ``typing.Any``.

        Returns:
            A new subclass of ``cls`` whose ``__args__`` is the resolved
            ``(shape, dtype)`` pair.
        """
        if not isinstance(args, tuple):
            # just shape passed
            shape, dtype = args, Any
        elif len(args) == 1:
            # ``NDArray[shape,]``: unwrap the single element rather than
            # passing the tuple itself on as the shape
            shape, dtype = args[0], Any
        else:
            shape, dtype = args[0], args[1]

        # Unbound type variables impose no constraint. Each position is checked
        # independently, so a TypeVar shape does not discard a concrete dtype.
        if isinstance(shape, TypeVar):
            shape = Any
        if isinstance(dtype, TypeVar):
            dtype = Any

        shape = _get_shape(shape)
        dtype = _get_dtype(dtype)

        return type(cls.__name__, (cls,), {**cls.__dict__, "__args__": (shape, dtype)})

    def __get_pydantic_core_schema__(
        cls,
        _source_type: "NDArray",
        _handler: "CallbackGetCoreSchemaHandler",
    ) -> core_schema.CoreSchema:
        """
        Build the pydantic core schema for a parameterized ``NDArray``.

        Validation uses the function from :func:`.get_validate_interface`.
        JSON-mode serialization uses :func:`.jsonize_array`. The schema from
        :func:`.make_json_schema` is stored in ``metadata`` so that
        :meth:`__get_pydantic_json_schema__` can render it later.
        """
        shape, dtype = _source_type.__args__
        shape: ShapeType
        dtype: DtypeType

        # make core schema for json schema, store it and any model definitions
        # so that we can use them when rendering json schema
        json_schema = make_json_schema(shape, dtype, _handler)

        return core_schema.with_info_plain_validator_function(
            get_validate_interface(shape, dtype),
            serialization=core_schema.plain_serializer_function_ser_schema(
                jsonize_array, when_used="json", info_arg=True
            ),
            metadata=json_schema,
        )

    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> core_schema.JsonSchema:
        """
        Render the JSON schema stored in the core schema's ``metadata``.

        If the dtype is a named, non-tuple object whose module is not
        ``builtins``, ``typing`` or ``types`` (e.g. ``numpy.float64``), its
        dotted path is added to the schema under the ``"dtype"`` key.
        """
        dtype = cls.__args__[1]
        json_schema = handler(schema["metadata"])
        json_schema = handler.resolve_ref_schema(json_schema)

        # ``getattr`` with a default: ``_get_dtype`` passes arbitrary dtypes
        # through unchanged, and instances of builtin types (e.g. a ``str``)
        # have no ``__module__`` attribute.
        module = getattr(dtype, "__module__", None)
        if (
            not isinstance(dtype, tuple)
            and hasattr(dtype, "__name__")
            and module is not None
            and module not in ("builtins", "typing", "types")
        ):
            json_schema["dtype"] = ".".join([module, dtype.__name__])

        return json_schema
234

235

236
@runtime_checkable
class NDArray(Protocol[TShape, TDType], metaclass=NDArrayMeta):
    """
    Constrained array type allowing nptyping syntax for dtype and shape
    validation and serialization.

    This class is not intended to be instantiated, and its support for static
    type checking is limited. Through its metaclass it implements
    ``__get_pydantic_core_schema__``, which invokes the relevant
    :ref:`interface <Interfaces>` for validation and serialization.

    It is callable, however: calling it validates its input and attempts to
    coerce it to a supported array type. There is no such thing as an
    "NDArray instance"; think of it instead as a validating passthrough
    callable.

    References:
        - https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
    """

    # Protocol member mirroring ``numpy.ndarray.shape``.
    # NOTE(review): the getter handed to ``property`` is numpy's attribute
    # descriptor, which is not itself callable, so reading ``.shape`` through
    # this property would fail. It appears to serve only as a declaration;
    # confirm this is intended.
    shape = property(fget=np.ndarray.shape)

    def __getitem__(self: Any, key: Any) -> Self:
        """Protocol stub for item access; intentionally has no implementation."""

    def __setitem__(self: Any, key: Any, value: Any) -> Any:
        """Protocol stub for item assignment; intentionally has no implementation."""
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc