• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ilotoki0804 / fieldenum / 10341035723

11 Aug 2024 04:01PM UTC coverage: 95.187% (+0.05%) from 95.139%
10341035723

push

github

ilotoki0804
Update dependencies

969 of 1018 relevant lines covered (95.19%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.31
/src/fieldenum/_fieldenum.py
1
"""fieldenum core implementation."""
1✔
2

3
from __future__ import annotations
1✔
4

5
import copyreg
1✔
6
import types
1✔
7
import typing
1✔
8
import warnings
1✔
9
from contextlib import suppress
1✔
10

11
from ._utils import NotAllowed, OneTimeSetter, ParamlessSingletonMeta, unpickle
1✔
12
from .exceptions import unreachable
13

14
# Type variable standing for the enum class a unit descriptor is looked up on;
# used by the ``UnitDescriptor.__get__`` overload annotations below.
Base = typing.TypeVar("Base")
1✔
15

16

17
class Variant:
1✔
18
    __slots__ = ("name", "field", "attached", "_slots_names", "_base", "_generics", "_actual", "_defaults", "_kw_only")
1✔
19

20
    def __set_name__(self, owner, name):
1✔
21
        self.name = name
1✔
22

23
    def __get__(self, obj, objtype=None) -> typing.Self:
1✔
24
        if self.attached:
1✔
25
            # This is needed in order to make match statements work.
26
            return self._actual  # type: ignore
1✔
27

28
        return self
×
29

30
    @classmethod
1✔
31
    def kw_only(cls, **named_field):
1✔
32
        self = cls(**named_field)
1✔
33
        self._kw_only = True
1✔
34
        return self
1✔
35

36
    # fieldless variant
37
    @typing.overload
1✔
38
    def __init__(self) -> None: ...
1✔
39

40
    # tuple variant
41
    @typing.overload
1✔
42
    def __init__(self, *tuple_field) -> None: ...
1✔
43

44
    # named variant
45
    @typing.overload
1✔
46
    def __init__(self, **named_field) -> None: ...
1✔
47

48
    def __init__(self, *tuple_field, **named_field) -> None:
1✔
49
        self.attached = False
1✔
50
        self._kw_only = False
1✔
51
        self._defaults = {}
1✔
52
        if tuple_field and named_field:
1✔
53
            raise TypeError("Cannot mix tuple fields and named fields. Use named fields.")
1✔
54
        self.field = (tuple_field, named_field)
1✔
55
        if named_field:
1✔
56
            self._slots_names = tuple(named_field)
1✔
57
        else:
58
            self._slots_names = tuple(f"_{i}" for i in range(len(tuple_field)))
1✔
59

60
    if typing.TYPE_CHECKING:
61
        def dump(self): ...
62

63
    def check_type(self, field, value, /):
1✔
64
        """Should raise error when type is mismatched."""
65

66
        if field in (typing.Any, typing.Self):
1✔
67
            return
1✔
68

69
        if type(field) in (
1✔
70
            typing.TypeAlias, types.GenericAlias, typing.TypeVar,  # typing-only things
71
            getattr(typing, "TypeAliasType", None),  # type aliases (added in python 3.12)
72
            str,  # possibly type alias
73
        ):
74
            return
1✔
75

76
        try:
1✔
77
            if isinstance(value, field):
1✔
78
                return
1✔
79
        except TypeError:
×
80
            warnings.warn(
×
81
                f"`isinstance` raised TypeError which mean type of field({field!r}, type: {type(field).__name__}) is not supported. "
82
                "Contact developer to `check_type` supports it."
83
            )
84
            return
×
85

86
        raise TypeError(f"Type of value is not expected. Expected type: {field!r}, actual type: {type(value)!r} and value: {value!r}")
1✔
87

88
    def with_defaults(self, **defaults) -> typing.Self:
1✔
89
        _, named_field = self.field
1✔
90
        if not named_field:
1✔
91
            raise TypeError("Only named variants can have defaults.")
1✔
92

93
        self._defaults = defaults
1✔
94
        return self
1✔
95

96
    def attach(
1✔
97
        self,
98
        cls,
99
        /,
100
        *,
101
        eq: bool,
102
        build_hash: bool,
103
        frozen: bool,
104
        runtime_check: bool,
105
    ) -> None | typing.Self:
106
        if self.attached:
1✔
107
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")
×
108

109
        self._base = cls
1✔
110
        tuple_field, named_field = self.field
1✔
111
        if not self._kw_only:
1✔
112
            named_field_keys = tuple(named_field)
1✔
113
        item = self
1✔
114

115
        self._actual: ConstructedVariant
1✔
116

117
        # fmt: off
118
        class ConstructedVariant(cls):
1✔
119
            if frozen and not typing.TYPE_CHECKING:
1✔
120
                __slots__ = tuple(f"__original_{name}" for name in item._slots_names)
1✔
121
                for name in item._slots_names:
1✔
122
                    # to prevent potential security risk
123
                    if name.isidentifier():
1✔
124
                        exec(f"{name} = OneTimeSetter()")
1✔
125
                    else:
126
                        unreachable(name)
127
                        OneTimeSetter()  # Show IDEs that OneTimeSetter is used. Not executed at runtime.
×
128
            else:
129
                __slots__ = item._slots_names
1✔
130

131
        if tuple_field:
1✔
132
            class TupleConstructedVariant(ConstructedVariant):
1✔
133
                __name__ = item.name
1✔
134
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
135
                __fields__ = tuple(range(len(tuple_field)))
1✔
136
                __slots__ = ()
1✔
137
                __match_args__ = item._slots_names
1✔
138

139
                if build_hash:
1✔
140
                    __slots__ += ("_hash",)
1✔
141

142
                    if frozen:
1✔
143
                        def __hash__(self) -> int:
1✔
144
                            with suppress(AttributeError):
1✔
145
                                return self._hash
1✔
146

147
                            self._hash = hash(self.dump())
1✔
148
                            return self._hash
1✔
149
                    else:
150
                        __hash__ = None  # type: ignore
1✔
151

152
                if eq:
1✔
153
                    def __eq__(self, other: typing.Self):
1✔
154
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
155

156
                def __repr__(self) -> str:
1✔
157
                    values_repr = ", ".join(repr(getattr(self, f"_{name}" if isinstance(name, int) else name)) for name in self.__fields__)
1✔
158
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
159

160
                @staticmethod
1✔
161
                def _pickle(variant):
1✔
162
                    assert isinstance(variant, ConstructedVariant)
1✔
163
                    return unpickle, (cls, self.name, tuple(getattr(variant, f"_{i}") for i in variant.__fields__))
1✔
164

165
                def dump(self) -> tuple:
1✔
166
                    return tuple(getattr(self, f"_{name}") for name in self.__fields__)
1✔
167

168
                def __init__(self, *args) -> None:
1✔
169
                    if len(tuple_field) != len(args):
1✔
170
                        raise TypeError(f"Expect {len(tuple_field)} field(s), but received {len(args)} argument(s).")
1✔
171

172
                    for name, field, value in zip(item._slots_names, tuple_field, args, strict=True):
1✔
173
                        if runtime_check:
1✔
174
                            getattr(self, "check_type", item.check_type)(field, value)
1✔
175
                        setattr(self, name, value)
1✔
176
            self._actual = TupleConstructedVariant
1✔
177

178
        elif named_field:
1✔
179
            class NamedConstructedVariant(ConstructedVariant):
1✔
180
                __name__ = item.name
1✔
181
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
182
                __fields__ = item._slots_names
1✔
183
                __defaults__ = item._defaults
1✔
184
                __slots__ = ()
1✔
185
                if not item._kw_only:
1✔
186
                    __match_args__ = item._slots_names
1✔
187

188
                if build_hash:
1✔
189
                    __slots__ += ("_hash",)
1✔
190

191
                    if frozen:
1✔
192
                        def __hash__(self) -> int:
1✔
193
                            with suppress(AttributeError):
1✔
194
                                return self._hash
1✔
195

196
                            self._hash = hash(tuple(self.dump().items()))
1✔
197
                            return self._hash
1✔
198
                    else:
199
                        __hash__ = None  # type: ignore
1✔
200

201
                if eq:
1✔
202
                    def __eq__(self, other: typing.Self):
1✔
203
                        return type(self) is type(other) and self.dump() == other.dump()
×
204

205
                @staticmethod
1✔
206
                def _pickle(variant):
1✔
207
                    assert isinstance(variant, ConstructedVariant)
1✔
208
                    return unpickle, (cls, self.name, {name: getattr(variant, name) for name in variant.__fields__})
1✔
209

210
                def dump(self):
1✔
211
                    return {name: getattr(self, name) for name in self.__fields__}
1✔
212

213
                if eq:
1✔
214
                    def __eq__(self, other: typing.Self):
1✔
215
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
216

217
                def __repr__(self) -> str:
1✔
218
                    values_repr = ', '.join(f'{name}={getattr(self, f"_{name}" if isinstance(name, int) else name)!r}' for name in self.__fields__)
1✔
219
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
220

221
                def __init__(self, *args, **kwargs) -> None:
1✔
222
                    if args:
1✔
223
                        if item._kw_only:
1✔
224
                            raise TypeError(f"Variant '{type(self).__qualname__}' is keyword only.")
1✔
225

226
                        if len(args) > len(named_field_keys):
1✔
227
                            raise TypeError(f"{self.__name__} takes {len(named_field_keys)} positional argument(s) but {len(args)} were/was given")
1✔
228

229
                        # a valid use case of zip without strict=True
230
                        for arg, field_name in zip(args, named_field_keys):
1✔
231
                            if field_name in kwargs:
1✔
232
                                raise TypeError(f"Inconsistent input for field '{field_name}': received both positional and keyword values")
1✔
233
                            kwargs[field_name] = arg
1✔
234

235
                    if self.__defaults__:
1✔
236
                        kwargs = self.__defaults__ | kwargs
1✔
237

238
                    if missed_keys := kwargs.keys() ^ named_field.keys():
1✔
239
                        raise TypeError(f"Key mismatch: {missed_keys}")
1✔
240

241
                    for name in named_field:
1✔
242
                        value = kwargs[name]
1✔
243
                        field = named_field[name]
1✔
244

245
                        if runtime_check:
1✔
246
                            getattr(self, "check_type", item.check_type)(field, value)
1✔
247
                        setattr(self, name, value)
1✔
248

249
            self._actual = NamedConstructedVariant
1✔
250

251
        else:
252
            class FieldlessConstructedVariant(ConstructedVariant, metaclass=ParamlessSingletonMeta):
1✔
253
                __name__ = item.name
1✔
254
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
255
                __fields__ = ()
1✔
256
                __slots__ = ()
1✔
257

258
                if build_hash and not frozen:
1✔
259
                    __hash__ = None  # type: ignore
1✔
260
                else:
261
                    def __hash__(self):
1✔
262
                        return hash(id(self))
1✔
263

264
                @staticmethod
1✔
265
                def _pickle(variant):
1✔
266
                    assert isinstance(variant, ConstructedVariant)
1✔
267
                    return unpickle, (cls, self.name, ())
1✔
268

269
                def dump(self):
1✔
270
                    return ()
1✔
271

272
                def __repr__(self) -> str:
1✔
273
                    values_repr = ""
1✔
274
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
275

276
                def __init__(self) -> None:
1✔
277
                    pass
1✔
278
            self._actual = FieldlessConstructedVariant
1✔
279
        # fmt: on
280

281
        copyreg.pickle(self._actual, self._actual._pickle)
1✔
282
        self.attached = True
1✔
283

284
    def __call__(self, *args, **kwargs):
1✔
285
        return self._actual(*args, **kwargs)
×
286

287

288
# MARK: UnitDescriptor
289

290

291
class UnitDescriptor:
    """Placeholder for a unit variant, i.e. a member used without calling it.

    Assigned in a fieldenum class body. ``__set_name__`` swaps in a named copy
    of the descriptor. ``attach`` later replaces that copy on the class with an
    instance of a dedicated subclass of the enum class.
    """

    __slots__ = ("name",)
    __fields__ = ()

    def __init__(self, name: str | None = None):
        self.name = name

    def __set_name__(self, owner, name):
        # Install a fresh, named descriptor instead of mutating this (possibly
        # shared) object.
        setattr(owner, name, UnitDescriptor(name))

    @typing.overload
    def __get__(self, obj, objtype: type[Base] = ...) -> Base: ...  # type: ignore

    @typing.overload
    def __get__(self, obj, objtype: None = ...) -> typing.Self: ...

    def __get__(self, obj, objtype: type[Base] | None = None) -> Base | typing.Self:
        # Only reachable before ``attach`` replaces the descriptor on the class.
        return self

    def attach(
        self,
        cls,
        /,
        *,
        eq: bool,  # not needed since nothing to check equality
        build_hash: bool,
        frozen: bool,
        runtime_check: bool,  # nothing to check
    ):
        """Replace this descriptor on ``cls`` with an instance of a unit variant class.

        The variant class subclasses ``cls`` (using ``ParamlessSingletonMeta``)
        and is registered with ``copyreg`` for pickling. It hashes by identity
        unless ``build_hash`` is set and ``frozen`` is not, in which case it is
        unhashable.

        Raises:
            TypeError: if the descriptor was never given a name via ``__set_name__``.
        """
        if self.name is None:
            raise TypeError("`self.name` is not set.")

        class UnitConstructedVariant(cls, metaclass=ParamlessSingletonMeta):
            __name__ = self.name
            # Match the qualified name given to tuple/named/fieldless variants.
            __qualname__ = f"{cls.__qualname__}.{self.name}"
            __slots__ = ()
            __fields__ = None  # `None` means it does not require calling for initialize.

            if build_hash and not frozen:
                __hash__ = None  # type: ignore # Explicitly disable hash
            else:
                def __hash__(self):
                    return hash(id(self))

            def dump(self):
                return None

            @staticmethod
            def _pickle(variant):
                assert isinstance(variant, UnitConstructedVariant)
                return unpickle, (cls, self.name, None)

            def __init__(self):
                pass

            def __repr__(self):
                return f"{cls.__name__}.{self.__name__}"

        copyreg.pickle(UnitConstructedVariant, UnitConstructedVariant._pickle)

        # This will replace Unit to Specialized instance.
        setattr(cls, self.name, UnitConstructedVariant())
1✔
352

353

354
# Shared placeholder assigned in a fieldenum class body (e.g. ``Name = Unit``).
# ``UnitDescriptor.__set_name__`` installs a separate named descriptor on the
# owner, so this shared instance itself keeps ``name=None``.
Unit = UnitDescriptor()
1✔
355

356

357
def fieldenum(
    cls=None,
    /,
    *,
    eq: bool = True,
    frozen: bool = True,
    runtime_check: bool = False,
):
    """Turn ``cls`` into a fieldenum by attaching every declared variant.

    Usable bare (``@fieldenum``) or with options (``@fieldenum(eq=False)``);
    in the latter form ``cls`` is ``None`` and a decorator is returned.

    Args:
        cls: The class to convert, or ``None`` when only options are given.
        eq: Forwarded to each variant's ``attach``. Also requests a generated
            ``__hash__`` unless ``cls`` defines ``__hash__`` itself.
        frozen: Forwarded to each variant's ``attach``.
        runtime_check: Forwarded to each variant's ``attach``.

    Returns:
        ``cls`` passed through ``typing.final``, or a decorator when ``cls``
        is ``None``.

    Raises:
        TypeError: if a base class of ``cls`` has a truthy ``__final__``
            attribute; anything raised by a variant's ``attach``.
    """
    if cls is None:
        def decorator(cls):
            return fieldenum(cls, eq=eq, frozen=frozen, runtime_check=runtime_check)

        return decorator

    def marked_final(base) -> bool:
        # Best effort: failing to read or evaluate ``__final__`` means "not final".
        try:
            return bool(base.__final__)
        except Exception:
            return False

    # Results of this decorator are marked with ``typing.final``, so refuse to
    # build on top of such a class. Only the bases are inspected (``cls``
    # itself is skipped); decorating the same class twice is instead caught
    # when ``Variant.attach`` sees an already attached variant.
    if any(marked_final(base) for base in cls.mro()[1:]):
        raise TypeError(
            "One of the base classes of fieldenum class is marked as final, "
            "which means it does not want to be subclassed and it may be fieldenum class, "
            "which should not be subclassed."
        )

    namespace = vars(cls)
    defines_own_hash = "__hash__" in namespace
    variant_kinds = Variant | UnitDescriptor

    for member in namespace.values():
        if not isinstance(member, variant_kinds):
            continue
        member.attach(
            cls,
            eq=eq,
            build_hash=eq and not defines_own_hash,
            frozen=frozen,
            runtime_check=runtime_check,
        )

    # Only the variants are meant to be instantiated, not the base class.
    # Best effort: any failure while installing the guard is ignored.
    try:
        cls.__init__ = NotAllowed("Base fieldenums cannot be initialized.", name="__init__")
    except Exception:
        pass

    return typing.final(cls)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc