• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

p2p-ld / numpydantic / 17003619767

16 Aug 2025 03:11AM UTC coverage: 97.365%. First build
17003619767

Pull #58

github

web-flow
Merge 445cf039e into 1dab66f12
Pull Request #58: Convert `NDArray` to a Protocol

44 of 48 new or added lines in 2 files covered. (91.67%)

1515 of 1556 relevant lines covered (97.37%)

9.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.5
/src/numpydantic/ndarray.py
1
"""
2
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
3

4
.. note::
5

6
    This module should *only* have the :class:`.NDArray` class in it, because the
7
    type stub ``ndarray.pyi`` is only created for :class:`.NDArray` . Otherwise,
8
    type checkers will complain about using any helper functions elsewhere -
9
    those all belong in :mod:`numpydantic.schema` .
10

11
    Keeping with nptyping's style, NDArrayMeta is in this module even if it's
12
    excluded from the type stub.
13

14
"""
15

16
from typing import (
10✔
17
    TYPE_CHECKING,
18
    Any,
19
    Literal,
20
    Protocol,
21
    Tuple,
22
    TypeVar,
23
    _ProtocolMeta,
24
    get_origin,
25
    runtime_checkable,
26
)
27

28
import numpy as np
10✔
29
from pydantic import GetJsonSchemaHandler
10✔
30
from pydantic_core import core_schema
10✔
31

32
from numpydantic.dtype import DType
10✔
33
from numpydantic.exceptions import InterfaceError
10✔
34
from numpydantic.interface import Interface
10✔
35
from numpydantic.maps import python_to_nptyping
10✔
36
from numpydantic.schema import (
10✔
37
    get_validate_interface,
38
    make_json_schema,
39
)
40
from numpydantic.serialization import jsonize_array
10✔
41
from numpydantic.types import DtypeType, NDArrayType, ShapeType
10✔
42
from numpydantic.validation.dtype import is_union
10✔
43
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
10✔
44
from numpydantic.vendor.nptyping.structure import Structure
10✔
45
from numpydantic.vendor.nptyping.structure_expression import check_type_names
10✔
46
from numpydantic.vendor.nptyping.typing_ import (
10✔
47
    dtype_per_name,
48
)
49

50
if TYPE_CHECKING:  # pragma: no cover
51
    from pydantic._internal._schema_generation_shared import (
52
        CallbackGetCoreSchemaHandler,
53
    )
54

55
    from numpydantic import Shape
56

57

58
def _is_literal_like(item: Any) -> bool:
    """
    Return whether ``item`` is a subscripted :data:`typing.Literal`.

    Changes from nptyping: rather than duck-typing for something literal-ish,
    the ``typing`` origin of ``item`` is inspected directly. Only
    ``Literal[...]`` forms qualify; a bare, unsubscripted ``Literal`` has no
    origin and therefore does not.
    """
    origin = get_origin(item)
    return origin is Literal
64

65

66
def _get_shape(dtype_candidate: Any) -> "Shape":
    """
    Normalize the shape argument of an ``NDArray[...]`` subscription.

    Overrides nptyping's base method so that numpydantic's local
    :class:`~numpydantic.validation.shape.Shape` is used.

    Args:
        dtype_candidate: The shape argument. Despite the parameter name, this
            holds a *shape*: ``typing.Any``, the bare ``Shape`` class, a
            ``Shape`` subclass, or ``Literal[<ShapeExpression>]``.

    Returns:
        ``typing.Any`` when the shape is unconstrained, otherwise a ``Shape``
        class.

    Raises:
        InvalidArgumentsError: if ``dtype_candidate`` matches none of the
            accepted forms.
    """
    # deferred import (presumably to avoid an import cycle) - TODO confirm
    from numpydantic.validation.shape import Shape

    if dtype_candidate is Any or dtype_candidate is Shape:
        # a bare ``Shape`` places no constraint, exactly like ``Any``
        return Any
    if issubclass(dtype_candidate, Shape):
        return dtype_candidate
    if _is_literal_like(dtype_candidate):
        # Literal["3, 4"] -> Shape["3, 4"]
        return Shape[dtype_candidate.__args__[0]]

    raise InvalidArgumentsError(
        f"Unexpected argument '{dtype_candidate}', expecting"
        " Shape[<ShapeExpression>]"
        " or Literal[<ShapeExpression>]"
        " or typing.Any."
    )
87

88

89
def _get_dtype(dtype_candidate: Any) -> DType:
    """
    Normalize the dtype argument of an ``NDArray[...]`` subscription.

    Overrides nptyping's base ``_get_dtype`` so that compound (tuple) dtypes
    are allowed; each member of a tuple is normalized recursively.

    Args:
        dtype_candidate: ``typing.Any``, a python type (translated through
            ``python_to_nptyping`` when it has an entry there), a numpy scalar
            type, a union, a ``Structure``, ``Literal[<StructureExpression>]``,
            or a tuple of these.

    Returns:
        The normalized dtype. Anything unrecognized is handed back unchanged
        so that it can fail (or succeed) during validation instead.
    """
    if dtype_candidate in python_to_nptyping:
        dtype_candidate = python_to_nptyping[dtype_candidate]

    if dtype_candidate is Any:
        return Any

    is_numpy_scalar = isinstance(dtype_candidate, type) and issubclass(
        dtype_candidate, np.generic
    )
    if is_numpy_scalar or is_union(dtype_candidate):
        return dtype_candidate

    if issubclass(dtype_candidate, Structure):  # pragma: no cover
        check_type_names(dtype_candidate, dtype_per_name)
        return dtype_candidate

    if _is_literal_like(dtype_candidate):  # pragma: no cover
        structure = Structure[dtype_candidate.__args__[0]]
        check_type_names(structure, dtype_per_name)
        return structure

    if isinstance(dtype_candidate, tuple):  # pragma: no cover
        return tuple(_get_dtype(member) for member in dtype_candidate)

    # arbitrary dtype - allow failure elsewhere :)
    return dtype_candidate
117

118

119
# Type parameters of the ``NDArray`` protocol (shape, then dtype), bounded by
# the ``ShapeType`` and ``DtypeType`` aliases imported from numpydantic.types.
TShape, TDType = (
    TypeVar("TShape", bound=ShapeType),
    TypeVar("TDType", bound=DtypeType),
)
121

122

123
class NDArrayMeta(_ProtocolMeta):
    """
    Metaclass to provide class-level methods to NDArray protocol
    without suggesting they are part of the protocol definition.
    """

    # Default ``(shape, dtype)`` constraints. Subscripted ``NDArray`` classes are
    # created with their own normalized pair stored under ``__args__``.
    __args__: Tuple[ShapeType, DtypeType] = (Any, Any)

    def __call__(cls, val: NDArrayType) -> NDArrayType:
        """
        Call ndarray as a validator function.

        ``NDArray[...](value)`` does not construct an instance: it runs the
        validator from :func:`.get_validate_interface` for this class's shape
        and dtype against ``value`` and returns that validator's result.
        """
        return get_validate_interface(cls.__args__[0], cls.__args__[1])(val)

    def __instancecheck__(self, instance: Any) -> bool:
        """
        Extended type checking that determines whether

        1) the ``type`` of the given instance is one of those in
            :meth:`.Interface.input_types`

        but also

        2) it satisfies the constraints set on the :class:`.NDArray` annotation

        Args:
            instance (:class:`typing.Any`): Thing to check!

        Returns:
            bool: ``True`` if matches constraints, ``False`` otherwise.
        """
        shape, dtype = self.__args__
        try:
            # build the matching interface for this class's constraints and
            # run its validation against the instance
            interface_cls = Interface.match(instance, fast=True)
            interface = interface_cls(shape, dtype)
            _ = interface.validate(instance)
            return True
        except InterfaceError:
            # NOTE(review): presumably raised both when no interface matches and
            # when validation fails - confirm against Interface.match/validate
            return False

    def _dtype_to_str(cls, dtype: Any) -> str:
        """
        Render a dtype constraint as a human-readable string.

        Args:
            dtype: ``typing.Any``, a tuple of dtypes, or any single dtype
                (numpy type, ``Structure``, union, ...).

        Returns:
            str: ``"Any"`` for ``typing.Any``, the comma-joined ``str`` of each
            member for a tuple, and ``str(dtype)`` for everything else.
        """
        if dtype is Any:
            return "Any"
        # Check for tuples first: ``issubclass`` may raise TypeError for
        # non-class arguments, so it must not gate the tuple branch.
        if isinstance(dtype, tuple):
            return ", ".join([str(dt) for dt in dtype])
        # Structures and every other dtype share the plain ``str`` rendering.
        return str(dtype)
171

172

173
@runtime_checkable
class NDArray(Protocol[TShape, TDType], metaclass=NDArrayMeta):
    """
    Constrained array type allowing nptyping syntax for dtype and shape validation
    and serialization.

    This class is not intended to be instantiable, and support for static type
    checking is limited. Instead, it implements the ``__get_pydantic_core_schema__``
    method to invoke the relevant :ref:`interface <Interfaces>` for validation
    and serialization.

    It is callable, however, which validates and attempts to coerce input to a
    supported array type.
    There is no such thing as an "NDArray instance," but one can think of it
    as a validating passthrough callable.

    References:
        - https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
    """

    # ``(shape, dtype)`` constraints; each subscripted subclass gets its own pair
    __args__: Tuple[ShapeType, DtypeType] = (Any, Any)

    def __class_getitem__(cls, args: type[Any] | tuple[type[Any], type[Any]]):
        """
        Create a constrained subclass, e.g. ``NDArray[Shape["3, 4"], np.uint8]``.

        Args:
            args: Either a single shape argument (bare, or wrapped in a
                one-element tuple) or a ``(shape, dtype)`` pair. An argument
                that is a :class:`typing.TypeVar` is treated as unconstrained
                (``typing.Any``).

        Returns:
            A subclass of ``cls`` whose ``__args__`` holds the normalized
            ``(shape, dtype)`` pair.

        Raises:
            InvalidArgumentsError: propagated from shape normalization when the
                shape argument is not an accepted form.
        """
        if isinstance(args, tuple) and len(args) > 1:
            shape, dtype = args[0], args[1]
        else:
            # just shape passed - unwrap a one-element tuple to its contents
            shape = args[0] if isinstance(args, tuple) else args
            dtype = Any

        # A TypeVar means "not specified": leave that constraint open. Each
        # argument is checked on its own (the dtype used to be gated on the
        # shape argument's type).
        if isinstance(shape, TypeVar):
            shape = Any
        if isinstance(dtype, TypeVar):
            dtype = Any

        shape = _get_shape(shape)
        dtype = _get_dtype(dtype)

        return type(cls.__name__, (cls,), {**cls.__dict__, "__args__": (shape, dtype)})

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: "NDArray",
        _handler: "CallbackGetCoreSchemaHandler",
    ) -> core_schema.CoreSchema:
        """
        Build the pydantic core schema for a (possibly subscripted) NDArray.

        Validation uses the validator from :func:`.get_validate_interface`,
        JSON-mode serialization uses :func:`.jsonize_array`, and the schema from
        :func:`.make_json_schema` is stored in ``metadata`` so that
        :meth:`__get_pydantic_json_schema__` can render it.
        """
        shape, dtype = _source_type.__args__
        shape: ShapeType
        dtype: DtypeType

        # make core schema for json schema, store it and any model definitions
        # so that we can use them when rendering json schema
        json_schema = make_json_schema(shape, dtype, _handler)

        return core_schema.with_info_plain_validator_function(
            get_validate_interface(shape, dtype),
            serialization=core_schema.plain_serializer_function_ser_schema(
                jsonize_array, when_used="json", info_arg=True
            ),
            metadata=json_schema,
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> core_schema.JsonSchema:
        """
        Render the JSON schema stored in the core schema's ``metadata``.

        When the dtype is a single named object from a module other than
        ``builtins``/``typing``/``types`` (for example a numpy scalar type),
        its dotted path is recorded under the ``"dtype"`` key.
        """
        dtype = cls.__args__[1]
        json_schema = handler(schema["metadata"])
        json_schema = handler.resolve_ref_schema(json_schema)

        # getattr rather than attribute access: some dtype objects (e.g.
        # instances of C-level types) have no ``__module__``/``__name__``, and
        # those should simply be skipped instead of raising AttributeError.
        module = getattr(dtype, "__module__", None)
        name = getattr(dtype, "__name__", None)
        if (
            not isinstance(dtype, tuple)
            and isinstance(module, str)
            and module
            not in (
                "builtins",
                "typing",
                "types",
            )
            and isinstance(name, str)
        ):
            json_schema["dtype"] = ".".join([module, name])

        return json_schema
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc