• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ilotoki0804 / fieldenum / 10788611744

10 Sep 2024 08:22AM UTC coverage: 95.71% (+0.7%) from 95.005%
10788611744

push

github

ilotoki0804
Add `__variants__` dunder

5 of 5 new or added lines in 1 file covered. (100.0%)

13 existing lines in 3 files now uncovered.

1160 of 1212 relevant lines covered (95.71%)

0.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.05
/src/fieldenum/_fieldenum.py
1
"""fieldenum core implementation."""
1✔
2

3
from __future__ import annotations
1✔
4

5
import copyreg
1✔
6
import inspect
1✔
7
import types
1✔
8
import typing
1✔
9
from contextlib import suppress
1✔
10

11
from ._utils import NotAllowed, OneTimeSetter, ParamlessSingletonMeta, unpickle
1✔
12
from .exceptions import unreachable
13

14
# Generic type variable used by `factory` and by the `UnitDescriptor.__get__` overloads.
T = typing.TypeVar("T")
1✔
15

16

17
class Variant:  # MARK: Variant
1✔
18
    __slots__ = (
1✔
19
        "name",
20
        "field",
21
        "attached",
22
        "_slots_names",
23
        "_base",
24
        "_generics",
25
        "_actual",
26
        "_defaults_and_factories",
27
        "_kw_only",
28
    )
29

30
    def __set_name__(self, owner, name) -> None:
1✔
31
        if self.attached:
1✔
32
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")
1✔
33
        self.name = name
1✔
34

35
    def __get__(self, obj, objtype=None) -> typing.Self:
1✔
36
        if self.attached:
1✔
37
            # This is needed in order to make match statements work.
38
            return self._actual  # type: ignore
1✔
39

UNCOV
40
        return self
×
41

42
    def kw_only(self) -> typing.Self:
1✔
43
        self._kw_only = True
1✔
44
        return self
1✔
45

46
    # fieldless variant
47
    @typing.overload
1✔
48
    def __init__(self) -> None: ...
1✔
49

50
    # tuple variant
51
    @typing.overload
1✔
52
    def __init__(self, *tuple_field) -> None: ...
1✔
53

54
    # named variant
55
    @typing.overload
1✔
56
    def __init__(self, **named_field) -> None: ...
1✔
57

58
    def __init__(self, *tuple_field, **named_field) -> None:
1✔
59
        self.attached = False
1✔
60
        self._kw_only = False
1✔
61
        self._defaults_and_factories = {}
1✔
62
        if tuple_field and named_field:
1✔
63
            raise TypeError("Cannot mix tuple fields and named fields. Use named fields.")
1✔
64
        self.field = (tuple_field, named_field)
1✔
65
        if named_field:
1✔
66
            self._slots_names = tuple(named_field)
1✔
67
        else:
68
            self._slots_names = tuple(f"_{i}" for i in range(len(tuple_field)))
1✔
69

70
    if typing.TYPE_CHECKING:
71
        def dump(self): ...
72

73
    def default(self, **defaults_and_factories) -> typing.Self:
1✔
74
        _, named_field = self.field
1✔
75
        if not named_field:
1✔
76
            raise TypeError("Only named variants can have defaults.")
1✔
77

78
        self._defaults_and_factories.update(defaults_and_factories)
1✔
79
        return self
1✔
80

81
    def attach(
1✔
82
        self,
83
        cls,
84
        /,
85
        *,
86
        eq: bool,
87
        build_hash: bool,
88
        frozen: bool,
89
    ) -> None | typing.Self:
90
        if self.attached:
1✔
91
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")
1✔
92

93
        self._base = cls
1✔
94
        tuple_field, named_field = self.field
1✔
95
        if not self._kw_only:
1✔
96
            named_field_keys = tuple(named_field)
1✔
97
        item = self
1✔
98

99
        self._actual: ConstructedVariant
1✔
100

101
        # fmt: off
102
        class ConstructedVariant(cls):
1✔
103
            if frozen and not typing.TYPE_CHECKING:
1✔
104
                __slots__ = tuple(f"__original_{name}" for name in item._slots_names)
1✔
105
                for name in item._slots_names:
1✔
106
                    # to prevent potential security risk
107
                    if name.isidentifier():
1✔
108
                        exec(f"{name} = OneTimeSetter()")
1✔
109
                    else:
110
                        unreachable(name)
UNCOV
111
                        OneTimeSetter()  # Show IDEs that OneTimeSetter is used. Not executed at runtime.
×
112
            else:
113
                __slots__ = item._slots_names
1✔
114

115
        if tuple_field:
1✔
116
            class TupleConstructedVariant(ConstructedVariant):
1✔
117
                __name__ = item.name
1✔
118
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
119
                __fields__ = tuple(range(len(tuple_field)))
1✔
120
                __slots__ = ()
1✔
121
                __match_args__ = item._slots_names
1✔
122

123
                if build_hash:
1✔
124
                    __slots__ += ("_hash",)
1✔
125

126
                    if frozen:
1✔
127
                        def __hash__(self) -> int:
1✔
128
                            with suppress(AttributeError):
1✔
129
                                return self._hash
1✔
130

131
                            self._hash = hash(self.dump())
1✔
132
                            return self._hash
1✔
133
                    else:
134
                        __hash__ = None  # type: ignore
1✔
135

136
                if eq:
1✔
137
                    def __eq__(self, other: typing.Self):
1✔
138
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
139

140
                def __repr__(self) -> str:
1✔
141
                    values_repr = ", ".join(repr(getattr(self, f"_{name}" if isinstance(name, int) else name)) for name in self.__fields__)
1✔
142
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
143

144
                @staticmethod
1✔
145
                def _pickle(variant):
1✔
146
                    assert isinstance(variant, ConstructedVariant)
1✔
147
                    return unpickle, (cls, self.name, tuple(getattr(variant, f"_{i}") for i in variant.__fields__), {})
1✔
148

149
                def dump(self) -> tuple:
1✔
150
                    return tuple(getattr(self, f"_{name}") for name in self.__fields__)
1✔
151

152
                def __init__(self, *args) -> None:
1✔
153
                    if len(tuple_field) != len(args):
1✔
154
                        raise TypeError(f"Expect {len(tuple_field)} field(s), but received {len(args)} argument(s).")
1✔
155

156
                    for name, field, value in zip(item._slots_names, tuple_field, args, strict=True):
1✔
157
                        setattr(self, name, value)
1✔
158

159
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
160
                    post_init()
1✔
161

162
            self._actual = TupleConstructedVariant
1✔
163

164
        elif named_field:
1✔
165
            class NamedConstructedVariant(ConstructedVariant):
1✔
166
                __name__ = item.name
1✔
167
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
168
                __fields__ = item._slots_names
1✔
169
                __slots__ = ()
1✔
170
                if not item._kw_only:
1✔
171
                    __match_args__ = item._slots_names
1✔
172

173
                if build_hash:
1✔
174
                    __slots__ += ("_hash",)
1✔
175

176
                    if frozen:
1✔
177
                        def __hash__(self) -> int:
1✔
178
                            with suppress(AttributeError):
1✔
179
                                return self._hash
1✔
180

181
                            self._hash = hash(tuple(self.dump().items()))
1✔
182
                            return self._hash
1✔
183
                    else:
184
                        __hash__ = None  # type: ignore
1✔
185

186
                if eq:
1✔
187
                    def __eq__(self, other: typing.Self):
1✔
188
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
189

190
                @staticmethod
1✔
191
                def _pickle(variant):
1✔
192
                    assert isinstance(variant, ConstructedVariant)
1✔
193
                    return unpickle, (cls, self.name, (), {name: getattr(variant, name) for name in variant.__fields__})
1✔
194

195
                def dump(self):
1✔
196
                    return {name: getattr(self, name) for name in self.__fields__}
1✔
197

198
                def __repr__(self) -> str:
1✔
199
                    values_repr = ', '.join(f'{name}={getattr(self, f"_{name}" if isinstance(name, int) else name)!r}' for name in self.__fields__)
1✔
200
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
201

202
                def __init__(self, *args, **kwargs) -> None:
1✔
203
                    if args:
1✔
204
                        if item._kw_only:
1✔
205
                            raise TypeError(f"Variant '{type(self).__qualname__}' is keyword only.")
1✔
206

207
                        if len(args) > len(named_field_keys):
1✔
208
                            raise TypeError(f"{self.__name__} takes {len(named_field_keys)} positional argument(s) but {len(args)} were/was given")
1✔
209

210
                        # a valid use case of zip without strict=True
211
                        for arg, field_name in zip(args, named_field_keys):
1✔
212
                            if field_name in kwargs:
1✔
213
                                raise TypeError(f"Inconsistent input for field '{field_name}': received both positional and keyword values")
1✔
214
                            kwargs[field_name] = arg
1✔
215

216
                    if item._defaults_and_factories:
1✔
217
                        for name, default_or_factory in item._defaults_and_factories.items():
1✔
218
                            if name not in kwargs:
1✔
219
                                kwargs[name] = factory._produce_from(default_or_factory)
1✔
220

221
                    if missed_keys := kwargs.keys() ^ named_field.keys():
1✔
222
                        raise TypeError(f"Key mismatch: {missed_keys}")
1✔
223

224
                    for name in named_field:
1✔
225
                        value = kwargs[name]
1✔
226
                        # field = named_field[name]
227
                        setattr(self, name, value)
1✔
228

229
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
230
                    post_init()
1✔
231

232
            self._actual = NamedConstructedVariant
1✔
233

234
        else:
235
            class FieldlessConstructedVariant(ConstructedVariant, metaclass=ParamlessSingletonMeta):
1✔
236
                __name__ = item.name
1✔
237
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
238
                __fields__ = ()
1✔
239
                __slots__ = ()
1✔
240

241
                if build_hash and not frozen:
1✔
242
                    __hash__ = None  # type: ignore
1✔
243
                else:
244
                    def __hash__(self):
1✔
245
                        return hash(id(self))
1✔
246

247
                @staticmethod
1✔
248
                def _pickle(variant):
1✔
249
                    assert isinstance(variant, ConstructedVariant)
1✔
250
                    return unpickle, (cls, self.name, (), {})
1✔
251

252
                def dump(self):
1✔
253
                    return ()
1✔
254

255
                def __repr__(self) -> str:
1✔
256
                    values_repr = ""
1✔
257
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
258

259
                def __init__(self) -> None:
1✔
260
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
261
                    post_init()
1✔
262

263
            self._actual = FieldlessConstructedVariant
1✔
264
        # fmt: on
265

266
        copyreg.pickle(self._actual, self._actual._pickle)
1✔
267
        self.attached = True
1✔
268

269
    def __call__(self, *args, **kwargs):
1✔
270
        return self._actual(*args, **kwargs)
×
271

272

273
# Parameter kinds that can be passed positionally; function-variant
# parameters of these kinds make up the variant's `__match_args__`.
POSITIONALS = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
1✔
274

275

276
class _FunctionVariant(Variant):  # MARK: FunctionVariant
    """Variant whose fields come from the parameters of a function.

    ``@variant`` on a function creates one of these. Each parameter except a
    leading ``self`` becomes a field, and the positional(-or-keyword) ones
    form ``__match_args__``. Parameter defaults fill in missing values
    (``factory`` defaults are produced on every construction). If the
    function takes ``self``, it is called after the fields are set and must
    return ``None``.
    """

    __slots__ = ("_func", "_signature", "_match_args", "_self_included")
    name: str

    def __init__(self, func: types.FunctionType) -> None:
        assert type(func) is types.FunctionType, "Type other than function is not allowed."
        self.attached = False
        self._func = func
        signature = inspect.signature(func)
        parameters_raw = signature.parameters
        self._signature = signature

        # `next(..., None)`: a function without any parameter must not leak
        # StopIteration out of this constructor.
        self._self_included = next(iter(parameters_raw), None) == "self"

        parameter_names = tuple(parameters_raw)
        if self._self_included:
            parameter_names = parameter_names[1:]
        self._slots_names = parameter_names
        self.field = ((), parameter_names)
        self._match_args = tuple(
            name for name in parameter_names
            if parameters_raw[name].kind in POSITIONALS
        )

    def kw_only(self) -> typing.NoReturn:
        raise TypeError(
            "`.kw_only()` method cannot be used in function variant. "
            "Use function keyword-only specifier(asterisk) instead."
        )

    def default(self, **_) -> typing.NoReturn:
        raise TypeError(
            "`.default()` method cannot be used in function variant. "
            "Use function defaults instead."
        )

    def attach(
        self,
        cls,
        /,
        *,
        eq: bool,
        build_hash: bool,
        frozen: bool,
    ) -> None | typing.Self:
        """Create the concrete subclass of ``cls`` for this function variant.

        Args:
            cls: the fieldenum class being decorated.
            eq: whether to generate ``__eq__`` comparing ``dump()`` results.
            build_hash: whether hashing is generated here. With ``frozen``, a
                cached hash of the field values is generated; otherwise
                ``__hash__`` is set to ``None``.
            frozen: whether fields are stored through ``OneTimeSetter``
                class attributes instead of plain slots.

        Raises:
            TypeError: if the variant is already attached.
        """
        if self.attached:
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")

        self._base = cls
        item = self

        if frozen:
            for slot_name in self._slots_names:
                # Slot names are pasted into the source executed below, so only
                # plain identifiers are accepted (to prevent potential security risk).
                if not slot_name.isidentifier():
                    unreachable(slot_name)
                    OneTimeSetter()  # Show IDEs that OneTimeSetter is used. Not executed at runtime.
        # One `<field> = OneTimeSetter()` statement per field, executed in the
        # class body below. Building the source here keeps a loop variable out
        # of the class namespace, so a parameter called e.g. `name` is not
        # overwritten.
        setters_source = "\n".join(f"{slot_name} = OneTimeSetter()" for slot_name in self._slots_names)

        # fmt: off
        class ConstructedVariant(cls):
            if frozen and not typing.TYPE_CHECKING:
                __slots__ = tuple(f"__original_{name}" for name in item._slots_names)
                # exec() without explicit namespaces writes into this class body's namespace.
                exec(setters_source)
            else:
                __slots__ = item._slots_names

            __name__ = item.name
            __qualname__ = f"{cls.__qualname__}.{item.name}"
            __fields__ = item._slots_names
            __match_args__ = item._match_args

            if build_hash:
                __slots__ += ("_hash",)

                if frozen:
                    def __hash__(self) -> int:
                        # Computed once, then served from the `_hash` slot.
                        with suppress(AttributeError):
                            return self._hash

                        self._hash = hash(tuple(self.dump().items()))
                        return self._hash
                else:
                    __hash__ = None  # type: ignore

            if eq:
                def __eq__(self, other: typing.Self):
                    return type(self) is type(other) and self.dump() == other.dump()

            def _get_positions(self) -> tuple[dict[str, typing.Any], dict[str, typing.Any]]:
                """Split field values into (positional, keyword) dicts by ``__match_args__``."""
                match_args = self.__match_args__
                args_dict = {}
                kwargs = {}
                for name in self.__fields__:
                    if name in match_args:
                        args_dict[name] = getattr(self, name)
                    else:
                        kwargs[name] = getattr(self, name)
                return args_dict, kwargs

            @staticmethod
            def _pickle(variant):
                assert isinstance(variant, ConstructedVariant)
                args_dict, kwargs = variant._get_positions()
                return unpickle, (cls, item.name, tuple(args_dict.values()), kwargs)

            def dump(self):
                return {name: getattr(self, name) for name in self.__fields__}

            def __repr__(self) -> str:
                args_dict, kwargs = self._get_positions()
                args_repr = ", ".join(repr(value) for value in args_dict.values())
                kwargs_repr = ", ".join(
                    f'{name}={value!r}'
                    for name, value in kwargs.items()
                )

                if args_repr and kwargs_repr:
                    values_repr = f"{args_repr}, {kwargs_repr}"
                else:
                    values_repr = f"{args_repr}{kwargs_repr}"

                return f"{item._base.__name__}.{self.__name__}({values_repr})"

            def __init__(self, *args, **kwargs) -> None:
                bound = (
                    item._signature.bind(None, *args, **kwargs)
                    if item._self_included
                    else item._signature.bind(*args, **kwargs)
                )

                # Like Signature.apply_defaults(), but `factory` defaults are
                # produced so that each instance gets its own value. Explicitly
                # passed values (and the `self` placeholder) are kept as-is,
                # matching named variants.
                arguments = bound.arguments
                new_arguments = {}
                for name, param in item._signature.parameters.items():
                    if name in arguments:
                        value = arguments[name]
                    elif param.kind is inspect.Parameter.VAR_POSITIONAL:
                        value = ()
                    elif param.kind is inspect.Parameter.VAR_KEYWORD:
                        value = {}
                    else:
                        assert param.default is not inspect.Parameter.empty, "Argument is not properly bound."
                        value = factory._produce_from(param.default)
                    new_arguments[name] = value
                bound.arguments = new_arguments  # type: ignore

                for name, value in new_arguments.items():
                    if item._self_included and name == "self":
                        continue
                    setattr(self, name, value)

                if item._self_included:
                    # The function body acts as an initializer receiving the instance.
                    bound.arguments["self"] = self
                    result = item._func(*bound.args, **bound.kwargs)
                    if result is not None:
                        raise TypeError("Initializer should return None.")
        # fmt: on

        self._actual = ConstructedVariant
        copyreg.pickle(self._actual, self._actual._pickle)
        self.attached = True
1✔
435

436

437
@typing.overload
def variant(cls: type, /) -> Variant: ...

@typing.overload
def variant(*, kw_only: bool = False) -> typing.Callable[[type], Variant]: ...

@typing.overload
def variant(func: types.FunctionType, /) -> Variant: ...

def variant(cls_or_func=None, /, *, kw_only: bool = False) -> typing.Any:  # MARK: variant
    """Build a variant from a class or function definition.

    Usable as ``@variant``, ``@variant(kw_only=True)`` or ``variant(obj)``.

    * A function becomes a ``_FunctionVariant`` whose fields are its
      parameters. ``kw_only=True`` is rejected for functions; a ``*`` in the
      signature expresses keyword-only fields instead.
    * A class becomes a named ``Variant``. Its annotations are the fields, and
      class attributes with the same names are used as defaults.

    Raises:
        TypeError: if ``kw_only=True`` is combined with a function.
    """
    if cls_or_func is None:
        return lambda cls_or_func: variant(cls_or_func, kw_only=kw_only)  # type: ignore

    if isinstance(cls_or_func, types.FunctionType):
        constructed = _FunctionVariant(cls_or_func)
        if kw_only:
            # Previously the flag was silently ignored for functions; let the
            # function variant raise its explanatory TypeError instead.
            constructed.kw_only()
        return constructed

    fields = cls_or_func.__annotations__
    defaults = {
        field_name: getattr(cls_or_func, field_name)
        for field_name in fields
        if hasattr(cls_or_func, field_name)
    }

    constructed = Variant(**fields).default(**defaults)
    if kw_only:
        constructed = constructed.kw_only()

    return constructed
1✔
465

466

467
class factory(typing.Generic[T]):  # MARK: factory
    """Wrap a zero-argument callable to be used as a default value.

    Wherever defaults are filled in, a ``factory`` is replaced by the result
    of calling its function, so the callable runs again each time instead of
    one object being shared.
    """

    def __init__(self, func: typing.Callable[[], T]):
        self.__factory = func

    @classmethod
    def _produce_from(cls, value: factory[T] | T) -> T:
        """Return ``value.produce()`` if ``value`` is a factory, else ``value`` unchanged."""
        if isinstance(value, factory):
            return value.produce()  # type: ignore
        return value  # type: ignore

    def produce(self) -> T:
        """Call the wrapped function and return its result."""
        make = self.__factory
        return make()
1✔
477

478

479
class UnitDescriptor:  # MARK: Unit
    """Placeholder for a unit variant (a variant that is used without calling it).

    The module-level ``Unit`` object is assigned in a fieldenum class body.
    ``__set_name__`` then stores a named copy under that attribute, and
    ``attach`` later replaces it with an instance of a subclass of the enum
    class.
    """

    __slots__ = ("name",)
    __fields__ = ()

    def __init__(self, name: str | None = None):
        self.name = name

    def __set_name__(self, owner, name):
        # The same `Unit` object may be assigned to many attributes, so store a
        # fresh descriptor that remembers this particular attribute name.
        setattr(owner, name, UnitDescriptor(name))

    @typing.overload
    def __get__(self, obj, objtype: type[T] = ...) -> T: ...  # type: ignore

    @typing.overload
    def __get__(self, obj, objtype: None = ...) -> typing.Self: ...

    def __get__(self, obj, objtype: type[T] | None = None) -> T | typing.Self:
        return self

    def attach(
        self,
        cls,
        /,
        *,
        eq: bool,  # not needed since nothing to check equality
        build_hash: bool,
        frozen: bool,
    ):
        """Replace this placeholder on ``cls`` with the unit variant instance.

        Raises:
            TypeError: if the descriptor has no name (it was never assigned
                in a class body).
        """
        if self.name is None:
            raise TypeError("`self.name` is not set.")

        class UnitConstructedVariant(cls, metaclass=ParamlessSingletonMeta):
            __name__ = self.name
            # Same qualified name scheme as the other constructed variants.
            __qualname__ = f"{cls.__qualname__}.{self.name}"
            __slots__ = ()
            __fields__ = None  # `None` means it does not require calling for initialize.

            if build_hash and not frozen:
                __hash__ = None  # type: ignore # Explicitly disable hash
            else:
                def __hash__(self):
                    return hash(id(self))

            def dump(self):
                return None

            @staticmethod
            def _pickle(variant):
                assert isinstance(variant, UnitConstructedVariant)
                return unpickle, (cls, self.name, None, None)

            def __init__(self):
                pass

            def __repr__(self):
                return f"{cls.__name__}.{self.__name__}"

        copyreg.pickle(UnitConstructedVariant, UnitConstructedVariant._pickle)

        # This will replace Unit to Specialized instance.
        setattr(cls, self.name, UnitConstructedVariant())
1✔
539

540

541
# Shared placeholder assigned in fieldenum class bodies; UnitDescriptor.__set_name__
# stores a separate, named UnitDescriptor under each attribute it is assigned to.
Unit = UnitDescriptor()
1✔
542

543

544
def fieldenum(
    cls=None,
    /,
    *,
    eq: bool = True,
    frozen: bool = True,
):
    """Turn ``cls`` into a fieldenum; usable as ``@fieldenum`` or ``@fieldenum(...)``.

    Every ``Variant`` / ``UnitDescriptor`` found in the class namespace is
    attached to ``cls``, and their attribute names are recorded in
    ``cls.__variants__``. Direct instantiation of ``cls`` is then disabled,
    and the class is marked final.

    Args:
        eq: generate ``__eq__`` on variants, plus a hash unless ``cls``
            defines its own ``__hash__``.
        frozen: store variant fields through ``OneTimeSetter`` attributes
            (presumably write-once; see ``_utils``).

    Raises:
        TypeError: if any base class is marked final. This also rejects a
            second application of the decorator.
    """
    if cls is None:
        return lambda cls: fieldenum(
            cls,
            eq=eq,
            frozen=frozen,
        )

    # Preventing subclassing fieldenums at runtime.
    # This also prevent double decoration.
    # Only a missing attribute is treated as "not final"; other errors propagate.
    if any(getattr(base, "__final__", False) for base in cls.mro()[1:]):
        raise TypeError(
            "One of the base classes of fieldenum class is marked as final, "
            "which means it does not want to be subclassed and it may be fieldenum class, "
            "which should not be subclassed."
        )

    class_attributes = vars(cls)
    has_own_hash = "__hash__" in class_attributes
    build_hash = eq and not has_own_hash

    attrs = []
    # Iterate over a snapshot: attaching a unit variant replaces the class
    # attribute via setattr() while we walk the namespace.
    for name, attr in list(class_attributes.items()):
        if isinstance(attr, Variant | UnitDescriptor):
            attr.attach(
                cls,
                eq=eq,
                build_hash=build_hash,
                frozen=frozen,
            )
            attrs.append(name)

    cls.__variants__ = attrs
    cls.__init__ = NotAllowed("A base fieldenum cannot be initialized.", name="__init__")

    return typing.final(cls)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc