• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ilotoki0804 / fieldenum / 10717501713

05 Sep 2024 09:16AM UTC coverage: 94.812% (-0.1%) from 94.912%
10717501713

push

github

ilotoki0804
Add `kw_only` parameter

5 of 7 new or added lines in 1 file covered. (71.43%)

1 existing line in 1 file now uncovered.

932 of 983 relevant lines covered (94.81%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.6
/src/fieldenum/_fieldenum.py
1
"""fieldenum core implementation."""
1✔
2

3
from __future__ import annotations
1✔
4

5
import copyreg
1✔
6
import typing
1✔
7
from contextlib import suppress
1✔
8

9
from ._utils import NotAllowed, OneTimeSetter, ParamlessSingletonMeta, unpickle
1✔
10
from .exceptions import unreachable
11

12
T = typing.TypeVar("T")
1✔
13

14

15
class Variant:
1✔
16
    __slots__ = ("name", "field", "attached", "_slots_names", "_base", "_generics", "_actual", "_defaults", "_default_factories", "_kw_only")
1✔
17

18
    def __set_name__(self, owner, name):
1✔
19
        self.name = name
1✔
20

21
    def __get__(self, obj, objtype=None) -> typing.Self:
1✔
22
        if self.attached:
1✔
23
            # This is needed in order to make match statements work.
24
            return self._actual  # type: ignore
1✔
25

26
        return self
×
27

28
    def kw_only(self) -> typing.Self:
1✔
29
        self._kw_only = True
1✔
30
        return self
1✔
31

32
    # fieldless variant
33
    @typing.overload
1✔
34
    def __init__(self) -> None: ...
1✔
35

36
    # tuple variant
37
    @typing.overload
1✔
38
    def __init__(self, *tuple_field) -> None: ...
1✔
39

40
    # named variant
41
    @typing.overload
1✔
42
    def __init__(self, **named_field) -> None: ...
1✔
43

44
    def __init__(self, *tuple_field, **named_field) -> None:
1✔
45
        self.attached = False
1✔
46
        self._kw_only = False
1✔
47
        self._defaults = {}
1✔
48
        self._default_factories = {}
1✔
49
        if tuple_field and named_field:
1✔
50
            raise TypeError("Cannot mix tuple fields and named fields. Use named fields.")
1✔
51
        self.field = (tuple_field, named_field)
1✔
52
        if named_field:
1✔
53
            self._slots_names = tuple(named_field)
1✔
54
        else:
55
            self._slots_names = tuple(f"_{i}" for i in range(len(tuple_field)))
1✔
56

57
    if typing.TYPE_CHECKING:
58
        def dump(self): ...
59

60
    def default(self, **defaults) -> typing.Self:
1✔
61
        _, named_field = self.field
1✔
62
        if not named_field:
1✔
63
            raise TypeError("Only named variants can have defaults.")
1✔
64

65
        self._defaults = defaults
1✔
66
        return self
1✔
67

68
    def default_factory(self, **default_factories) -> typing.Self:
1✔
69
        _, named_field = self.field
1✔
70
        if not named_field:
1✔
71
            raise TypeError("Only named variants can have defaults.")
×
72

73
        self._default_factories = default_factories
1✔
74
        return self
1✔
75

76
    def attach(
1✔
77
        self,
78
        cls,
79
        /,
80
        *,
81
        eq: bool,
82
        build_hash: bool,
83
        frozen: bool,
84
    ) -> None | typing.Self:
85
        if self.attached:
1✔
86
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")
×
87

88
        self._base = cls
1✔
89
        tuple_field, named_field = self.field
1✔
90
        if not self._kw_only:
1✔
91
            named_field_keys = tuple(named_field)
1✔
92
        item = self
1✔
93

94
        self._actual: ConstructedVariant
1✔
95

96
        # fmt: off
97
        class ConstructedVariant(cls):
1✔
98
            if frozen and not typing.TYPE_CHECKING:
1✔
99
                __slots__ = tuple(f"__original_{name}" for name in item._slots_names)
1✔
100
                for name in item._slots_names:
1✔
101
                    # to prevent potential security risk
102
                    if name.isidentifier():
1✔
103
                        exec(f"{name} = OneTimeSetter()")
1✔
104
                    else:
105
                        unreachable(name)
106
                        OneTimeSetter()  # Show IDEs that OneTimeSetter is used. Not executed at runtime.
×
107
            else:
108
                __slots__ = item._slots_names
1✔
109

110
        if tuple_field:
1✔
111
            class TupleConstructedVariant(ConstructedVariant):
1✔
112
                __name__ = item.name
1✔
113
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
114
                __fields__ = tuple(range(len(tuple_field)))
1✔
115
                __slots__ = ()
1✔
116
                __match_args__ = item._slots_names
1✔
117

118
                if build_hash:
1✔
119
                    __slots__ += ("_hash",)
1✔
120

121
                    if frozen:
1✔
122
                        def __hash__(self) -> int:
1✔
123
                            with suppress(AttributeError):
1✔
124
                                return self._hash
1✔
125

126
                            self._hash = hash(self.dump())
1✔
127
                            return self._hash
1✔
128
                    else:
129
                        __hash__ = None  # type: ignore
1✔
130

131
                if eq:
1✔
132
                    def __eq__(self, other: typing.Self):
1✔
133
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
134

135
                def __repr__(self) -> str:
1✔
136
                    values_repr = ", ".join(repr(getattr(self, f"_{name}" if isinstance(name, int) else name)) for name in self.__fields__)
1✔
137
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
138

139
                @staticmethod
1✔
140
                def _pickle(variant):
1✔
141
                    assert isinstance(variant, ConstructedVariant)
1✔
142
                    return unpickle, (cls, self.name, tuple(getattr(variant, f"_{i}") for i in variant.__fields__))
1✔
143

144
                def dump(self) -> tuple:
1✔
145
                    return tuple(getattr(self, f"_{name}") for name in self.__fields__)
1✔
146

147
                def __init__(self, *args) -> None:
1✔
148
                    if len(tuple_field) != len(args):
1✔
149
                        raise TypeError(f"Expect {len(tuple_field)} field(s), but received {len(args)} argument(s).")
×
150

151
                    for name, field, value in zip(item._slots_names, tuple_field, args, strict=True):
1✔
152
                        setattr(self, name, value)
1✔
153

154
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
155
                    post_init()
1✔
156

157
            self._actual = TupleConstructedVariant
1✔
158

159
        elif named_field:
1✔
160
            class NamedConstructedVariant(ConstructedVariant):
1✔
161
                __name__ = item.name
1✔
162
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
163
                __fields__ = item._slots_names
1✔
164
                __defaults__ = item._defaults
1✔
165
                __factories__ = item._default_factories
1✔
166
                __slots__ = ()
1✔
167
                if not item._kw_only:
1✔
168
                    __match_args__ = item._slots_names
1✔
169

170
                if build_hash:
1✔
171
                    __slots__ += ("_hash",)
1✔
172

173
                    if frozen:
1✔
174
                        def __hash__(self) -> int:
1✔
175
                            with suppress(AttributeError):
1✔
176
                                return self._hash
1✔
177

178
                            self._hash = hash(tuple(self.dump().items()))
1✔
179
                            return self._hash
1✔
180
                    else:
181
                        __hash__ = None  # type: ignore
1✔
182

183
                if eq:
1✔
184
                    def __eq__(self, other: typing.Self):
1✔
185
                        return type(self) is type(other) and self.dump() == other.dump()
×
186

187
                @staticmethod
1✔
188
                def _pickle(variant):
1✔
189
                    assert isinstance(variant, ConstructedVariant)
1✔
190
                    return unpickle, (cls, self.name, {name: getattr(variant, name) for name in variant.__fields__})
1✔
191

192
                def dump(self):
1✔
193
                    return {name: getattr(self, name) for name in self.__fields__}
1✔
194

195
                if eq:
1✔
196
                    def __eq__(self, other: typing.Self):
1✔
197
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
198

199
                def __repr__(self) -> str:
1✔
200
                    values_repr = ', '.join(f'{name}={getattr(self, f"_{name}" if isinstance(name, int) else name)!r}' for name in self.__fields__)
1✔
201
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
202

203
                def __init__(self, *args, **kwargs) -> None:
1✔
204
                    if args:
1✔
205
                        if item._kw_only:
1✔
206
                            raise TypeError(f"Variant '{type(self).__qualname__}' is keyword only.")
1✔
207

208
                        if len(args) > len(named_field_keys):
1✔
209
                            raise TypeError(f"{self.__name__} takes {len(named_field_keys)} positional argument(s) but {len(args)} were/was given")
1✔
210

211
                        # a valid use case of zip without strict=True
212
                        for arg, field_name in zip(args, named_field_keys):
1✔
213
                            if field_name in kwargs:
1✔
214
                                raise TypeError(f"Inconsistent input for field '{field_name}': received both positional and keyword values")
1✔
215
                            kwargs[field_name] = arg
1✔
216

217
                    if self.__defaults__:
1✔
218
                        kwargs = self.__defaults__ | kwargs
1✔
219

220
                    if self.__factories__:
1✔
221
                        for name, factory in self.__factories__.items():
1✔
222
                            if name not in kwargs:
1✔
223
                                kwargs[name] = factory()
1✔
224

225
                    if missed_keys := kwargs.keys() ^ named_field.keys():
1✔
226
                        raise TypeError(f"Key mismatch: {missed_keys}")
×
227

228
                    for name in named_field:
1✔
229
                        value = kwargs[name]
1✔
230
                        # field = named_field[name]
231
                        setattr(self, name, value)
1✔
232

233
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
234
                    post_init()
1✔
235

236
            self._actual = NamedConstructedVariant
1✔
237

238
        else:
239
            class FieldlessConstructedVariant(ConstructedVariant, metaclass=ParamlessSingletonMeta):
1✔
240
                __name__ = item.name
1✔
241
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
242
                __fields__ = ()
1✔
243
                __slots__ = ()
1✔
244

245
                if build_hash and not frozen:
1✔
246
                    __hash__ = None  # type: ignore
1✔
247
                else:
248
                    def __hash__(self):
1✔
249
                        return hash(id(self))
1✔
250

251
                @staticmethod
1✔
252
                def _pickle(variant):
1✔
253
                    assert isinstance(variant, ConstructedVariant)
1✔
254
                    return unpickle, (cls, self.name, ())
1✔
255

256
                def dump(self):
1✔
257
                    return ()
1✔
258

259
                def __repr__(self) -> str:
1✔
260
                    values_repr = ""
1✔
261
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
262

263
                def __init__(self) -> None:
1✔
264
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
265
                    post_init()
1✔
266

267
            self._actual = FieldlessConstructedVariant
1✔
268
        # fmt: on
269

270
        copyreg.pickle(self._actual, self._actual._pickle)
1✔
271
        self.attached = True
1✔
272

273
    def __call__(self, *args, **kwargs):
1✔
274
        return self._actual(*args, **kwargs)
×
275

276

277
def variant(cls=None, kw_only: bool = False):
    """Build a named ``Variant`` from the annotations of a class.

    Usable as ``@variant`` or as ``@variant(kw_only=True)``. Each annotated
    name becomes a named field; an annotated name that also resolves to an
    attribute on ``cls`` contributes that attribute's value as the default.

    Args:
        cls: the class whose annotations describe the fields, or ``None``
            when the decorator is called with arguments only.
        kw_only: when true, the resulting variant is marked keyword-only.

    Returns:
        The constructed ``Variant``, or a decorator when ``cls`` is ``None``.
    """
    if cls is None:
        # Called as ``@variant(...)``: defer until the class is supplied.
        return lambda cls: variant(cls, kw_only)

    fields = cls.__annotations__
    defaults = {field_name: getattr(cls, field_name) for field_name in fields if hasattr(cls, field_name)}

    constructed = Variant(**fields).default(**defaults)
    if kw_only:
        constructed = constructed.kw_only()
    return constructed
288

289

290
# MARK: UnitDescriptor
291

292

293
class UnitDescriptor:
    """Placeholder for a unit variant (a variant that is never called).

    Assigning a descriptor in a class body replaces it with a fresh
    ``UnitDescriptor`` carrying the attribute name. ``attach`` then replaces
    that attribute with an instance of a generated subclass of the owner.
    """

    __slots__ = ("name",)
    __fields__ = ()

    def __init__(self, name: str | None = None):
        self.name = name

    def __set_name__(self, owner, name):
        # Install a named copy on the owner rather than mutating ``self``,
        # so one descriptor object can be reused under several names.
        setattr(owner, name, UnitDescriptor(name))

    @typing.overload
    def __get__(self, obj, objtype: type[T] = ...) -> T: ...  # type: ignore

    @typing.overload
    def __get__(self, obj, objtype: None = ...) -> typing.Self: ...

    def __get__(self, obj, objtype: type[T] | None = None) -> T | typing.Self:
        return self

    def attach(
        self,
        cls,
        /,
        *,
        eq: bool,  # not needed since nothing to check equality
        build_hash: bool,
        frozen: bool,
    ):
        """Replace this descriptor on ``cls`` with a unit-variant instance.

        Args:
            cls: the owning class; the generated class subclasses it.
            eq: accepted for signature parity; unused here.
            build_hash: together with ``frozen`` decides whether hashing is
                disabled (``build_hash and not frozen``) or identity-based.
            frozen: see ``build_hash``.

        Raises:
            TypeError: if the descriptor has no name.
        """
        if self.name is None:
            raise TypeError("`self.name` is not set.")

        class UnitConstructedVariant(cls, metaclass=ParamlessSingletonMeta):
            __name__ = self.name
            __slots__ = ()
            __fields__ = None  # `None` means it does not require calling for initialize.

            if build_hash and not frozen:
                __hash__ = None  # type: ignore # Explicitly disable hash
            else:
                def __hash__(self):
                    return hash(id(self))

            def dump(self):
                return None

            @staticmethod
            def _pickle(variant):
                # ``self`` here is the enclosing descriptor, captured by closure.
                assert isinstance(variant, UnitConstructedVariant)
                return unpickle, (cls, self.name, None)

            def __init__(self):
                pass

            def __repr__(self):
                return f"{cls.__name__}.{self.__name__}"

        copyreg.pickle(UnitConstructedVariant, UnitConstructedVariant._pickle)

        # Replace the descriptor on the class with the specialized instance.
        setattr(cls, self.name, UnitConstructedVariant())
353

354

355
Unit = UnitDescriptor()
1✔
356

357

358
def fieldenum(
    cls=None,
    /,
    *,
    eq: bool = True,
    frozen: bool = True,
):
    """Class decorator that turns ``cls`` into a fieldenum.

    Usable as ``@fieldenum`` or ``@fieldenum(eq=..., frozen=...)``. Every
    ``Variant`` / ``UnitDescriptor`` found directly in the class body is
    attached to ``cls``, direct instantiation of ``cls`` is disabled on a
    best-effort basis, and the class is marked with ``typing.final``.

    Args:
        cls: the class to decorate, or ``None`` when called with options only.
        eq: forwarded to each variant's ``attach``.
        frozen: forwarded to each variant's ``attach``.

    Raises:
        TypeError: if any base class of ``cls`` has a truthy ``__final__``.
    """
    if cls is None:
        return lambda cls: fieldenum(
            cls,
            eq=eq,
            frozen=frozen,
        )

    # Preventing subclassing fieldenums at runtime.
    # This also prevent double decoration.
    is_final = False
    for base in cls.mro()[1:]:
        # Best effort: a base whose ``__final__`` cannot be read counts as not final.
        with suppress(Exception):
            if base.__final__:
                is_final = True
                break
    if is_final:
        raise TypeError(
            "One of the base classes of fieldenum class is marked as final, "
            "which means it does not want to be subclassed and it may be fieldenum class, "
            "which should not be subclassed."
        )

    class_attributes = vars(cls)
    has_own_hash = "__hash__" in class_attributes
    # A hash is generated only when equality is generated and the class has none of its own.
    build_hash = eq and not has_own_hash

    # Iterate over a snapshot: ``UnitDescriptor.attach`` rebinds class
    # attributes while this loop is running.
    for attr in tuple(class_attributes.values()):
        if isinstance(attr, (Variant, UnitDescriptor)):
            attr.attach(
                cls,
                eq=eq,
                build_hash=build_hash,
                frozen=frozen,
            )

    # Deliberately best-effort: failing to install the guard is ignored.
    with suppress(Exception):
        cls.__init__ = NotAllowed("Base fieldenums cannot be initialized.", name="__init__")

    return typing.final(cls)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc