• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ilotoki0804 / fieldenum / 10862198381

14 Sep 2024 11:59AM UTC coverage: 95.652% (-0.06%) from 95.71%
10862198381

push

github

ilotoki0804
Add `BoundResult.exit()`

3 of 4 new or added lines in 1 file covered. (75.0%)

6 existing lines in 1 file now uncovered.

1166 of 1219 relevant lines covered (95.65%)

0.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.07
/src/fieldenum/_fieldenum.py
1
"""fieldenum core implementation."""
1✔
2

3
from __future__ import annotations
1✔
4

5
import copyreg
1✔
6
import inspect
1✔
7
import types
1✔
8
import typing
1✔
9
from contextlib import suppress
1✔
10

11
from ._utils import NotAllowed, OneTimeSetter, ParamlessSingletonMeta, unpickle
1✔
12
from .exceptions import unreachable
13

14
# Generic type variable used by `factory` (the produced value's type) and by
# the `UnitDescriptor.__get__` overloads.
T = typing.TypeVar("T")
1✔
15

16

17
class Variant:  # MARK: Variant
1✔
18
    __slots__ = (
1✔
19
        "name",
20
        "field",
21
        "attached",
22
        "_slots_names",
23
        "_base",
24
        "_generics",
25
        "_actual",
26
        "_defaults_and_factories",
27
        "_kw_only",
28
    )
29

30
    def __set_name__(self, owner, name) -> None:
1✔
31
        if self.attached:
1✔
32
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")
1✔
33
        self.name = name
1✔
34

35
    def __get__(self, obj, objtype=None) -> typing.Self:
1✔
36
        if self.attached:
1✔
37
            # This is needed in order to make match statements work.
38
            return self._actual  # type: ignore
1✔
39

40
        return self
×
41

42
    def kw_only(self) -> typing.Self:
1✔
43
        self._kw_only = True
1✔
44
        return self
1✔
45

46
    # fieldless variant
47
    @typing.overload
1✔
48
    def __init__(self) -> None: ...
1✔
49

50
    # tuple variant
51
    @typing.overload
1✔
52
    def __init__(self, *tuple_field) -> None: ...
1✔
53

54
    # named variant
55
    @typing.overload
1✔
56
    def __init__(self, **named_field) -> None: ...
1✔
57

58
    def __init__(self, *tuple_field, **named_field) -> None:
1✔
59
        self.attached = False
1✔
60
        self._kw_only = False
1✔
61
        self._defaults_and_factories = {}
1✔
62
        if tuple_field and named_field:
1✔
63
            raise TypeError("Cannot mix tuple fields and named fields. Use named fields.")
1✔
64
        self.field = (tuple_field, named_field)
1✔
65
        if named_field:
1✔
66
            self._slots_names = tuple(named_field)
1✔
67
        else:
68
            self._slots_names = tuple(f"_{i}" for i in range(len(tuple_field)))
1✔
69

70
    if typing.TYPE_CHECKING:
71
        def dump(self): ...
72

73
    def default(self, **defaults_and_factories) -> typing.Self:
1✔
74
        _, named_field = self.field
1✔
75
        if not named_field:
1✔
76
            raise TypeError("Only named variants can have defaults.")
1✔
77

78
        self._defaults_and_factories.update(defaults_and_factories)
1✔
79
        return self
1✔
80

81
    def attach(
1✔
82
        self,
83
        cls,
84
        /,
85
        *,
86
        eq: bool,
87
        build_hash: bool,
88
        build_repr: bool,
89
        frozen: bool,
90
    ) -> None | typing.Self:
91
        if self.attached:
1✔
92
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")
1✔
93

94
        self._base = cls
1✔
95
        tuple_field, named_field = self.field
1✔
96
        if not self._kw_only:
1✔
97
            named_field_keys = tuple(named_field)
1✔
98
        item = self
1✔
99

100
        self._actual: ConstructedVariant
1✔
101

102
        # fmt: off
103
        class ConstructedVariant(cls):
1✔
104
            if frozen and not typing.TYPE_CHECKING:
1✔
105
                __slots__ = tuple(f"__original_{name}" for name in item._slots_names)
1✔
106
                for name in item._slots_names:
1✔
107
                    # to prevent potential security risk
108
                    if name.isidentifier():
1✔
109
                        exec(f"{name} = OneTimeSetter()")
1✔
110
                    else:
111
                        unreachable(name)
UNCOV
112
                        OneTimeSetter()  # Show IDEs that OneTimeSetter is used. Not executed at runtime.
×
113
            else:
114
                __slots__ = item._slots_names
1✔
115

116
        if tuple_field:
1✔
117
            class TupleConstructedVariant(ConstructedVariant):
1✔
118
                __name__ = item.name
1✔
119
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
120
                __fields__ = tuple(range(len(tuple_field)))
1✔
121
                __slots__ = ()
1✔
122
                __match_args__ = item._slots_names
1✔
123

124
                if build_hash:
1✔
125
                    __slots__ += ("_hash",)
1✔
126

127
                    if frozen:
1✔
128
                        def __hash__(self) -> int:
1✔
129
                            with suppress(AttributeError):
1✔
130
                                return self._hash
1✔
131

132
                            self._hash = hash(self.dump())
1✔
133
                            return self._hash
1✔
134
                    else:
135
                        __hash__ = None  # type: ignore
1✔
136

137
                if eq:
1✔
138
                    def __eq__(self, other: typing.Self):
1✔
139
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
140

141
                if build_repr:
1✔
142
                    def __repr__(self) -> str:
1✔
143
                        values_repr = ", ".join(repr(getattr(self, f"_{name}" if isinstance(name, int) else name)) for name in self.__fields__)
1✔
144
                        return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
145

146
                @staticmethod
1✔
147
                def _pickle(variant):
1✔
148
                    assert isinstance(variant, ConstructedVariant)
1✔
149
                    return unpickle, (cls, self.name, tuple(getattr(variant, f"_{i}") for i in variant.__fields__), {})
1✔
150

151
                def dump(self) -> tuple:
1✔
152
                    return tuple(getattr(self, f"_{name}") for name in self.__fields__)
1✔
153

154
                def __init__(self, *args) -> None:
1✔
155
                    if len(tuple_field) != len(args):
1✔
156
                        raise TypeError(f"Expect {len(tuple_field)} field(s), but received {len(args)} argument(s).")
1✔
157

158
                    for name, field, value in zip(item._slots_names, tuple_field, args, strict=True):
1✔
159
                        setattr(self, name, value)
1✔
160

161
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
162
                    post_init()
1✔
163

164
            self._actual = TupleConstructedVariant
1✔
165

166
        elif named_field:
1✔
167
            class NamedConstructedVariant(ConstructedVariant):
1✔
168
                __name__ = item.name
1✔
169
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
170
                __fields__ = item._slots_names
1✔
171
                __slots__ = ()
1✔
172
                if not item._kw_only:
1✔
173
                    __match_args__ = item._slots_names
1✔
174

175
                if build_hash:
1✔
176
                    __slots__ += ("_hash",)
1✔
177

178
                    if frozen:
1✔
179
                        def __hash__(self) -> int:
1✔
180
                            with suppress(AttributeError):
1✔
181
                                return self._hash
1✔
182

183
                            self._hash = hash(tuple(self.dump().items()))
1✔
184
                            return self._hash
1✔
185
                    else:
186
                        __hash__ = None  # type: ignore
1✔
187

188
                if eq:
1✔
189
                    def __eq__(self, other: typing.Self):
1✔
190
                        return type(self) is type(other) and self.dump() == other.dump()
1✔
191

192
                @staticmethod
1✔
193
                def _pickle(variant):
1✔
194
                    assert isinstance(variant, ConstructedVariant)
1✔
195
                    return unpickle, (cls, self.name, (), {name: getattr(variant, name) for name in variant.__fields__})
1✔
196

197
                def dump(self):
1✔
198
                    return {name: getattr(self, name) for name in self.__fields__}
1✔
199

200
                def __repr__(self) -> str:
1✔
201
                    values_repr = ', '.join(f'{name}={getattr(self, f"_{name}" if isinstance(name, int) else name)!r}' for name in self.__fields__)
1✔
202
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
203

204
                def __init__(self, *args, **kwargs) -> None:
1✔
205
                    if args:
1✔
206
                        if item._kw_only:
1✔
207
                            raise TypeError(f"Variant '{type(self).__qualname__}' is keyword only.")
1✔
208

209
                        if len(args) > len(named_field_keys):
1✔
210
                            raise TypeError(f"{self.__name__} takes {len(named_field_keys)} positional argument(s) but {len(args)} were/was given")
1✔
211

212
                        # a valid use case of zip without strict=True
213
                        for arg, field_name in zip(args, named_field_keys):
1✔
214
                            if field_name in kwargs:
1✔
215
                                raise TypeError(f"Inconsistent input for field '{field_name}': received both positional and keyword values")
1✔
216
                            kwargs[field_name] = arg
1✔
217

218
                    if item._defaults_and_factories:
1✔
219
                        for name, default_or_factory in item._defaults_and_factories.items():
1✔
220
                            if name not in kwargs:
1✔
221
                                kwargs[name] = factory._produce_from(default_or_factory)
1✔
222

223
                    if missed_keys := kwargs.keys() ^ named_field.keys():
1✔
224
                        raise TypeError(f"Key mismatch: {missed_keys}")
1✔
225

226
                    for name in named_field:
1✔
227
                        value = kwargs[name]
1✔
228
                        # field = named_field[name]
229
                        setattr(self, name, value)
1✔
230

231
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
232
                    post_init()
1✔
233

234
            self._actual = NamedConstructedVariant
1✔
235

236
        else:
237
            class FieldlessConstructedVariant(ConstructedVariant, metaclass=ParamlessSingletonMeta):
1✔
238
                __name__ = item.name
1✔
239
                __qualname__ = f"{cls.__qualname__}.{item.name}"
1✔
240
                __fields__ = ()
1✔
241
                __slots__ = ()
1✔
242

243
                if build_hash and not frozen:
1✔
244
                    __hash__ = None  # type: ignore
1✔
245
                else:
246
                    def __hash__(self):
1✔
247
                        return hash(id(self))
1✔
248

249
                @staticmethod
1✔
250
                def _pickle(variant):
1✔
251
                    assert isinstance(variant, ConstructedVariant)
1✔
252
                    return unpickle, (cls, self.name, (), {})
1✔
253

254
                def dump(self):
1✔
255
                    return ()
1✔
256

257
                def __repr__(self) -> str:
1✔
258
                    values_repr = ""
1✔
259
                    return f"{item._base.__name__}.{self.__name__}({values_repr})"
1✔
260

261
                def __init__(self) -> None:
1✔
262
                    post_init = getattr(self, "__post_init__", lambda: None)
1✔
263
                    post_init()
1✔
264

265
            self._actual = FieldlessConstructedVariant
1✔
266
        # fmt: on
267

268
        copyreg.pickle(self._actual, self._actual._pickle)
1✔
269
        self.attached = True
1✔
270

271
    def __call__(self, *args, **kwargs):
1✔
UNCOV
272
        return self._actual(*args, **kwargs)
×
273

274

275
# Parameter kinds that can be passed positionally. A function variant exposes
# only parameters of these kinds through `__match_args__`.
POSITIONALS = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
1✔
276

277

278
class _FunctionVariant(Variant):  # MARK: FunctionVariant
    """Variant whose fields are declared by the parameters of a function.

    Each parameter other than a leading ``self`` becomes a field.  Positional
    parameters form ``__match_args__``.  Every value, whether passed or taken
    from a parameter default, goes through ``factory._produce_from``.  When the
    first parameter is named ``self``, the function is called after the fields
    are set and must return ``None``.
    """

    __slots__ = ("_func", "_signature", "_match_args", "_self_included")
    name: str

    def __init__(self, func: types.FunctionType) -> None:
        """Derive the variant's fields from ``func``'s signature."""
        assert type(func) is types.FunctionType, "Type other than function is not allowed."
        self.attached = False
        self._func = func
        signature = inspect.signature(func)
        parameters_raw = signature.parameters
        self._signature = signature

        all_names = tuple(parameters_raw)
        # Inspect the tuple instead of calling `next()` on an iterator, so a
        # function without parameters does not leak a bare StopIteration.
        self._self_included = bool(all_names) and all_names[0] == "self"

        parameter_names = all_names[1:] if self._self_included else all_names
        self._slots_names = parameter_names
        self.field = ((), parameter_names)
        # Only positionally passable parameters take part in positional `case` patterns.
        self._match_args = tuple(
            name for name in parameter_names
            if parameters_raw[name].kind in POSITIONALS
        )

    def kw_only(self) -> typing.NoReturn:
        """Always raise: keyword-only fields are declared with ``*`` in the signature."""
        raise TypeError(
            "`.kw_only()` method cannot be used in function variant. "
            "Use function keyword-only specifier(asterisk) instead."
        )

    def default(self, **_) -> typing.NoReturn:
        """Always raise: defaults come from the function's parameter defaults."""
        raise TypeError(
            "`.default()` method cannot be used in function variant. "
            "Use function defaults instead."
        )

    def attach(
        self,
        cls,
        /,
        *,
        eq: bool,
        build_hash: bool,
        build_repr: bool,
        frozen: bool,
    ) -> None | typing.Self:
        """Create the concrete variant class as a subclass of ``cls``.

        Args:
            cls: The enum class this variant belongs to.
            eq: Generate ``__eq__`` that compares the exact type and ``dump()``.
            build_hash: Generate ``__hash__``. When ``frozen`` is set, this is the
                hash of ``dump()`` (cached in a ``_hash`` slot); otherwise
                instances are made unhashable.
            build_repr: Generate ``__repr__`` that prints positional fields
                by value and the remaining fields as ``name=value``.
            frozen: Install a ``OneTimeSetter`` descriptor for every field
                (presumably making fields write-once).

        Raises:
            TypeError: If the variant is already attached.
        """
        if self.attached:
            raise TypeError(f"This variants already attached to {self._base.__name__!r}.")

        self._base = cls
        item = self

        # fmt: off
        class ConstructedVariant(cls):
            if frozen and not typing.TYPE_CHECKING:
                __slots__ = tuple(f"__original_{name}" for name in item._slots_names)
                _fieldenum_field_name = None
                for _fieldenum_field_name in item._slots_names:
                    # to prevent potential security risk
                    if _fieldenum_field_name.isidentifier():
                        exec(f"{_fieldenum_field_name} = OneTimeSetter()")
                    else:
                        unreachable(_fieldenum_field_name)
                        OneTimeSetter()  # Show IDEs that OneTimeSetter is used. Not executed at runtime.
                # The loop variable would otherwise stay behind as a class
                # attribute, shadowing a same-named attribute of the enum
                # (e.g. `name`) and clobbering a field of that name.
                del _fieldenum_field_name
            else:
                __slots__ = item._slots_names

            # NOTE(review): the exec above writes into this class namespace, and
            # class-body reads of enclosing names below (`item`, `cls`, `eq`,
            # `frozen`, `build_hash`, `build_repr`) consult that namespace first,
            # so a parameter sharing one of those names could shadow them when
            # frozen. TODO confirm whether such parameter names should be rejected.
            __name__ = item.name
            __qualname__ = f"{cls.__qualname__}.{item.name}"
            __fields__ = item._slots_names
            __match_args__ = item._match_args

            if build_hash:
                __slots__ += ("_hash",)

                if frozen:
                    def __hash__(self) -> int:
                        # Computed once, then served from the `_hash` slot.
                        with suppress(AttributeError):
                            return self._hash

                        self._hash = hash(tuple(self.dump().items()))
                        return self._hash
                else:
                    __hash__ = None  # type: ignore

            def _get_positions(self) -> tuple[dict[str, typing.Any], dict[str, typing.Any]]:
                """Split field values into positional (in ``__match_args__``) and keyword ones."""
                match_args = self.__match_args__
                args_dict = {}
                kwargs = {}
                for name in self.__fields__:
                    if name in match_args:
                        args_dict[name] = getattr(self, name)
                    else:
                        kwargs[name] = getattr(self, name)
                return args_dict, kwargs

            @staticmethod
            def _pickle(variant):
                assert isinstance(variant, ConstructedVariant)
                args_dict, kwargs = variant._get_positions()
                return unpickle, (cls, item.name, tuple(args_dict.values()), kwargs)

            def dump(self):
                return {name: getattr(self, name) for name in self.__fields__}

            if eq:
                def __eq__(self, other: typing.Self):
                    return type(self) is type(other) and self.dump() == other.dump()

            if build_repr:
                def __repr__(self) -> str:
                    args_dict, kwargs = self._get_positions()
                    args_repr = ", ".join(repr(value) for value in args_dict.values())
                    kwargs_repr = ", ".join(
                        f'{name}={value!r}'
                        for name, value in kwargs.items()
                    )

                    if args_repr and kwargs_repr:
                        values_repr = f"{args_repr}, {kwargs_repr}"
                    else:
                        values_repr = f"{args_repr}{kwargs_repr}"

                    return f"{item._base.__name__}.{self.__name__}({values_repr})"

            def __init__(self, *args, **kwargs) -> None:
                # A `None` placeholder is bound to a leading `self` parameter;
                # the real instance is substituted before the initializer runs.
                bound = (
                    item._signature.bind(None, *args, **kwargs)
                    if item._self_included
                    else item._signature.bind(*args, **kwargs)
                )

                # code from Signature.apply_defaults(), extended so that every
                # value (given or default) is passed through factory._produce_from.
                arguments = bound.arguments
                new_arguments = {}
                for name, param in item._signature.parameters.items():
                    try:
                        value = arguments[name]
                    except KeyError:
                        assert param.default is not inspect.Parameter.empty, "Argument is not properly bound."
                        value = param.default
                    new_arguments[name] = factory._produce_from(value)
                bound.arguments = new_arguments  # type: ignore # why not OrderedDict? I don't know

                for name, value in new_arguments.items():
                    if item._self_included and name == "self":
                        continue
                    setattr(self, name, value)

                if item._self_included:
                    bound.arguments["self"] = self
                    result = item._func(*bound.args, **bound.kwargs)
                    if result is not None:
                        raise TypeError("Initializer should return None.")
        # fmt: on

        self._actual = ConstructedVariant
        copyreg.pickle(self._actual, self._actual._pickle)
        self.attached = True
1✔
439

440

441
@typing.overload
def variant(cls: type, /) -> Variant: ...

@typing.overload
def variant(*, kw_only: bool = False) -> typing.Callable[[type], Variant]: ...

@typing.overload
def variant(func: types.FunctionType, /) -> Variant: ...

def variant(cls_or_func=None, /, *, kw_only: bool = False) -> typing.Any:  # MARK: variant
    """Build a variant declaration from a class or a function.

    * A class becomes a :class:`Variant`: its annotations are the fields, and
      class attributes with the same names are the defaults.  A class without
      annotations becomes a fieldless variant.
    * A function becomes a function variant whose parameters are the fields.
    * Called without a positional argument, returns a decorator that applies
      the given options.

    Args:
        cls_or_func: The class or function to convert.
        kw_only: Make a class-based variant keyword-only.  Function variants
            reject this option; use ``*`` in the signature instead.

    Raises:
        TypeError: If ``kw_only`` is requested for a function variant.
    """
    if cls_or_func is None:
        return lambda cls_or_func: variant(cls_or_func, kw_only=kw_only)  # type: ignore

    if isinstance(cls_or_func, types.FunctionType):
        constructed = _FunctionVariant(cls_or_func)

    else:
        fields = cls_or_func.__annotations__
        defaults = {
            field_name: getattr(cls_or_func, field_name)
            for field_name in fields if hasattr(cls_or_func, field_name)
        }

        constructed = Variant(**fields)
        # `.default()` rejects non-named variants even when given nothing, so
        # only call it when there is something to register.
        if defaults:
            constructed = constructed.default(**defaults)

    if kw_only:
        # Function variants raise an explanatory TypeError here instead of
        # silently ignoring the request.
        constructed = constructed.kw_only()

    return constructed
1✔
469

470

471
class factory(typing.Generic[T]):  # MARK: factory
    """Wrap a zero-argument callable used as a per-instance default.

    Wherever variants resolve defaults, they pass the value through
    :meth:`_produce_from`, so a ``factory`` yields a fresh result each time
    instead of one shared object.
    """

    def __init__(self, func: typing.Callable[[], T]):
        self.__factory = func

    @classmethod
    def _produce_from(cls, value: factory[T] | T) -> T:
        """Return ``value.produce()`` for a factory, otherwise ``value`` itself."""
        if not isinstance(value, factory):
            return value  # type: ignore
        return value.produce()  # type: ignore

    def produce(self) -> T:
        """Invoke the wrapped callable and return its result."""
        make = self.__factory
        return make()
1✔
481

482

483
class UnitDescriptor:  # MARK: Unit
    """Declaration of a unit variant: a value used without calling it.

    ``fieldenum`` calls :meth:`attach`, which replaces the class attribute
    with an instance of a generated subclass of the enum.
    """

    __slots__ = ("name",)
    __fields__ = None

    def __init__(self, name: str | None = None):
        self.name = name

    def __set_name__(self, owner, name):
        # `Unit` is one shared module-level instance, so bind a fresh, named
        # descriptor per attribute instead of mutating the shared one.
        setattr(owner, name, UnitDescriptor(name))

    @typing.overload
    def __get__(self, obj, objtype: type[T] = ...) -> T: ...  # type: ignore

    @typing.overload
    def __get__(self, obj, objtype: None = ...) -> typing.Self: ...

    def __get__(self, obj, objtype: type[T] | None = None) -> T | typing.Self:
        return self

    def attach(
        self,
        cls,
        /,
        *,
        eq: bool,  # not needed since nothing to check equality
        build_hash: bool,
        build_repr: bool,
        frozen: bool,
    ):
        """Replace this descriptor on ``cls`` with the unit variant's instance.

        Args:
            cls: The enum class this unit belongs to.
            eq: Unused; a unit has no fields to compare.
            build_hash: Together with ``frozen``, selects between an
                unhashable unit (``build_hash and not frozen``) and an
                identity-based ``__hash__``.
            build_repr: Generate a ``__repr__`` of the form ``Enum.Name``.
            frozen: See ``build_hash``.

        Raises:
            TypeError: If the descriptor has no name (it was never bound via
                ``__set_name__``).
        """
        if self.name is None:
            raise TypeError("`self.name` is not set.")

        class UnitConstructedVariant(cls, metaclass=ParamlessSingletonMeta):
            __name__ = self.name
            # Match the qualified name set by every other constructed variant.
            __qualname__ = f"{cls.__qualname__}.{self.name}"
            __slots__ = ()
            __fields__ = None  # `None` means it does not require calling for initialize.

            if build_hash and not frozen:
                __hash__ = None  # type: ignore # Explicitly disable hash
            else:
                def __hash__(self):
                    return hash(id(self))

            def dump(self):
                return None

            @staticmethod
            def _pickle(variant):
                assert isinstance(variant, UnitConstructedVariant)
                return unpickle, (cls, self.name, None, None)

            def __init__(self):
                pass

            if build_repr:
                def __repr__(self):
                    return f"{cls.__name__}.{self.__name__}"

        copyreg.pickle(UnitConstructedVariant, UnitConstructedVariant._pickle)

        # This will replace Unit to Specialized instance.
        setattr(cls, self.name, UnitConstructedVariant())
1✔
545

546

547
# Shared placeholder for declaring unit variants in a fieldenum class body.
# Its `__set_name__` replaces it with a per-attribute, named `UnitDescriptor`.
Unit = UnitDescriptor()
1✔
548

549

550
def fieldenum(
    cls=None,
    /,
    *,
    eq: bool = True,
    frozen: bool = True,
):
    """Class decorator that turns ``cls`` into a fieldenum.

    Every ``Variant``/``UnitDescriptor`` found in the class namespace is
    attached to ``cls``.  Their attribute names are stored in
    ``cls.__variants__``, the base class is made non-instantiable, and the
    class is marked with ``typing.final``.

    Args:
        cls: The class to decorate; omit it to get a configured decorator.
        eq: Generate equality for the variants, and hashing unless ``cls``
            defines its own ``__hash__``.
        frozen: Forwarded to each variant's ``attach``.

    Raises:
        TypeError: If a base class of ``cls`` has a truthy ``__final__``.
    """
    if cls is None:
        return lambda cls: fieldenum(
            cls,
            eq=eq,
            frozen=frozen,
        )

    # Refuse to build on top of a final class; fieldenums themselves are
    # marked final at the end of this function, so they cannot be subclassed.
    if any(getattr(base, "__final__", False) for base in cls.mro()[1:]):
        raise TypeError(
            "One of the base classes of fieldenum class is marked as final, "
            "which means it does not want to be subclassed and it may be fieldenum class, "
            "which should not be subclassed."
        )

    class_attributes = vars(cls)
    has_own_hash = "__hash__" in class_attributes
    build_hash = eq and not has_own_hash
    build_repr = cls.__repr__ is object.__repr__

    attrs = []
    # Iterate over a snapshot: attaching a unit variant rebinds its attribute on `cls`.
    for name, attr in list(class_attributes.items()):
        if isinstance(attr, Variant | UnitDescriptor):
            attr.attach(
                cls,
                eq=eq,
                build_hash=build_hash,
                build_repr=build_repr,
                frozen=frozen,
            )
            attrs.append(name)

    cls.__variants__ = attrs
    cls.__init__ = NotAllowed("A base fieldenum cannot be initialized.", name="__init__")

    return typing.final(cls)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc