• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ets-labs / python-dependency-injector / 9377842010

05 Jun 2024 03:41AM UTC coverage: 0.0% (-92.0%) from 92.016%
9377842010

Pull #765

github

rmk135
Remove pypy 3.9
Pull Request #765: Add Python 3.12 Support (#752)

0 of 764 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/dependency_injector/wiring.py
1
"""Wiring module."""
2

3
import functools
×
4
import inspect
×
5
import importlib
×
6
import importlib.machinery
×
7
import pkgutil
×
8
import warnings
×
9
import sys
×
10
from types import ModuleType
×
11
from typing import (
×
12
    Optional,
13
    Iterable,
14
    Iterator,
15
    Callable,
16
    Any,
17
    Tuple,
18
    Dict,
19
    Generic,
20
    TypeVar,
21
    Type,
22
    Union,
23
    Set,
24
    cast,
25
)
26

27
if sys.version_info < (3, 7):
×
28
    from typing import GenericMeta
×
29
else:
30
    class GenericMeta(type):
×
31
        ...
×
32

33
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
34
if sys.version_info >= (3, 9):
×
35
    from types import GenericAlias
×
36
else:
37
    GenericAlias = None
×
38

39

40
try:
×
41
    import fastapi.params
×
42
except ImportError:
×
43
    fastapi = None
×
44

45

46
try:
×
47
    import starlette.requests
×
48
except ImportError:
×
49
    starlette = None
×
50

51

52
try:
×
53
    import werkzeug.local
×
54
except ImportError:
×
55
    werkzeug = None
×
56

57

58
from . import providers
×
59

60
# NOTE(review): this module also contains f-strings (Python 3.6+ syntax), so it
# cannot be compiled on Python 3.5 and this warning branch looks unreachable.
if (sys.version_info.major, sys.version_info.minor) == (3, 5):
    warnings.warn(
        "Dependency Injector will drop support of Python 3.5 after Jan 1st of 2022. "
        "This does not mean that there will be any immediate breaking changes, "
        "but tests will no longer be executed on Python 3.5, and bugs will not be addressed.",
        category=DeprecationWarning,
    )

# Public API of the wiring module.
__all__ = (
    "wire",
    "unwire",
    "inject",
    "as_int",
    "as_float",
    "as_",
    "required",
    "invariant",
    "provided",
    "Provide",
    "Provider",
    "Closing",
    "register_loader_containers",
    "unregister_loader_containers",
    "install_loader",
    "uninstall_loader",
    "is_loader_installed",
)

# Generic type variables: T for marker payloads, F for decorated callables.
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
# Alias used to annotate container arguments; left untyped on purpose here.
Container = Any
91

92

93
class PatchedRegistry:
    """Registry of everything wiring has patched.

    Callables are stored under their wrapper object so a wrapper can be mapped
    back to its ``PatchedCallable`` record. Patched attributes are kept in a set
    so they can be restored and forgotten per module.
    """

    def __init__(self) -> None:
        self._callables: Dict[Callable[..., Any], "PatchedCallable"] = {}
        self._attributes: Set["PatchedAttribute"] = set()

    def register_callable(self, patched: "PatchedCallable") -> None:
        """Store ``patched`` under its wrapper callable (``patched.patched``)."""
        self._callables[patched.patched] = patched

    def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
        """Yield the wrappers whose record reports being defined in ``module``."""
        for patched_callable in self._callables.values():
            if patched_callable.is_in_module(module):
                yield patched_callable.patched

    def get_callable(self, fn: Callable[..., Any]) -> Optional["PatchedCallable"]:
        """Return the record registered for wrapper ``fn``, or ``None`` if unknown."""
        return self._callables.get(fn)

    def has_callable(self, fn: Callable[..., Any]) -> bool:
        """Tell whether ``fn`` is a registered wrapper."""
        return fn in self._callables

    def register_attribute(self, patched: "PatchedAttribute") -> None:
        """Remember a patched attribute so it can be restored later."""
        self._attributes.add(patched)

    def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]:
        """Yield the patched attributes that belong to ``module``."""
        for attribute in self._attributes:
            if attribute.is_in_module(module):
                yield attribute

    def clear_module_attributes(self, module: ModuleType) -> None:
        """Forget every patched attribute that belongs to ``module``."""
        # Collect first, then subtract in place, so the set is never mutated
        # while it is being iterated.
        stale = {
            attribute for attribute in self._attributes
            if attribute.is_in_module(module)
        }
        self._attributes -= stale
128

129

130
class PatchedCallable:
    """Bookkeeping record for one injecting wrapper.

    ``reference_injections`` / ``reference_closing`` are private copies of the
    mappings given to the constructor. ``injections`` / ``closing`` start empty,
    are filled through ``add_injection`` / ``add_closing`` and are reset by
    ``unwind_injections``.
    """

    __slots__ = (
        "patched",
        "original",
        "reference_injections",
        "injections",
        "reference_closing",
        "closing",
    )

    def __init__(
            self,
            patched: Optional[Callable[..., Any]] = None,
            original: Optional[Callable[..., Any]] = None,
            reference_injections: Optional[Dict[Any, Any]] = None,
            reference_closing: Optional[Dict[Any, Any]] = None,
    ) -> None:
        self.patched = patched
        self.original = original

        # Copy so later changes to the caller's dicts do not leak into the record.
        self.reference_injections: Dict[Any, Any] = (
            {} if reference_injections is None else reference_injections.copy()
        )
        self.injections: Dict[Any, Any] = {}

        self.reference_closing: Dict[Any, Any] = (
            {} if reference_closing is None else reference_closing.copy()
        )
        self.closing: Dict[Any, Any] = {}

    def is_in_module(self, module: ModuleType) -> bool:
        """Tell whether the wrapper is defined in ``module`` (False if unset)."""
        wrapper = self.patched
        return wrapper is not None and wrapper.__module__ == module.__name__

    def add_injection(self, kwarg: Any, injection: Any) -> None:
        """Bind ``injection`` to keyword ``kwarg``."""
        self.injections[kwarg] = injection

    def add_closing(self, kwarg: Any, injection: Any) -> None:
        """Record ``injection`` under ``kwarg`` in the closing mapping."""
        self.closing[kwarg] = injection

    def unwind_injections(self) -> None:
        """Drop all bound injections and closing entries."""
        self.injections = {}
        self.closing = {}
175

176

177
class PatchedAttribute:
    """Record of a marker attribute that wiring replaced on a module or class."""

    __slots__ = (
        "member",
        "name",
        "marker",
    )

    def __init__(self, member: Any, name: str, marker: "_Marker") -> None:
        self.member = member  # module or class that owns the attribute
        self.name = name
        self.marker = marker  # original marker, put back on unwire

    @property
    def module_name(self) -> str:
        """Name of the module ``member`` is, or was defined in."""
        owner = self.member
        return owner.__name__ if isinstance(owner, ModuleType) else owner.__module__

    def is_in_module(self, module: ModuleType) -> bool:
        """Tell whether this attribute belongs to ``module``."""
        return module.__name__ == self.module_name
199

200

201
class ProvidersMap:
    """Translate providers referenced by markers into a container's providers.

    Markers usually reference providers of a declarative container class, while
    wiring needs the matching providers of the container being wired.
    ``__init__`` builds a map from the original (declarative parent) providers to
    the current container's providers, and ``resolve_provider`` consults it.
    """

    # String id that resolves to the container itself (its ``__self__`` provider).
    CONTAINER_STRING_ID = "<container>"

    def __init__(self, container) -> None:
        self._container = container
        self._map = self._create_providers_map(
            current_container=container,
            original_container=(
                container.declarative_parent
                if container.declarative_parent
                else container
            ),
        )

    def resolve_provider(
            self,
            provider: Union[providers.Provider, str],
            modifier: Optional["Modifier"] = None,
    ) -> Optional[providers.Provider]:
        """Return the current-container counterpart of ``provider``, or ``None``.

        ``provider`` may be a provider object or a dotted string id. ``modifier``
        is only applied on the string-id path.
        """
        if isinstance(provider, providers.Delegate):
            return self._resolve_delegate(provider)
        elif isinstance(provider, (
            providers.ProvidedInstance,
            providers.AttributeGetter,
            providers.ItemGetter,
            providers.MethodCaller,
        )):
            return self._resolve_provided_instance(provider)
        elif isinstance(provider, providers.ConfigurationOption):
            return self._resolve_config_option(provider)
        elif isinstance(provider, providers.TypedConfigurationOption):
            return self._resolve_config_option(provider.option, as_=provider.provides)
        elif isinstance(provider, str):
            return self._resolve_string_id(provider, modifier)
        else:
            return self._resolve_provider(provider)

    def _resolve_string_id(
            self,
            id: str,
            modifier: Optional["Modifier"] = None,
    ) -> Optional[providers.Provider]:
        """Resolve a dotted attribute path (e.g. ``"service.dependency"``) on the container.

        Returns ``None`` as soon as a path segment is missing.
        """
        if id == self.CONTAINER_STRING_ID:
            return self._container.__self__

        provider = self._container
        for segment in id.split("."):
            try:
                provider = getattr(provider, segment)
            except AttributeError:
                return None

        if modifier:
            provider = modifier.modify(provider, providers_map=self)
        return provider

    def _resolve_provided_instance(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Map the base of an accessor chain and replay the accessors on the result.

        ``original`` is a chain of ``.provided`` / attribute / item / method-call
        providers. The chain is unwound to its base provider, the base is
        mapped, and the accessors are re-applied innermost first.
        """
        accessors = []
        while isinstance(original, (
                providers.ProvidedInstance,
                providers.AttributeGetter,
                providers.ItemGetter,
                providers.MethodCaller,
        )):
            accessors.append(original)
            original = original.provides

        new = self._resolve_provider(original)
        if new is None:
            return None

        # ``accessors`` was collected outermost first; replay innermost first.
        for modifier in reversed(accessors):
            if isinstance(modifier, providers.ProvidedInstance):
                new = new.provided
            elif isinstance(modifier, providers.AttributeGetter):
                new = getattr(new, modifier.name)
            elif isinstance(modifier, providers.ItemGetter):
                new = new[modifier.name]
            elif isinstance(modifier, providers.MethodCaller):
                new = new.call(
                    *modifier.args,
                    **modifier.kwargs,
                )

        return new

    def _resolve_delegate(
            self,
            original: providers.Delegate,
    ) -> Optional[providers.Provider]:
        """Map the delegated provider and return its ``.provider``, or ``None``."""
        provider = self._resolve_provider(original.provides)
        # Explicit None check, matching the other resolvers: a resolved provider
        # must not be dropped just because it happens to be falsy.
        if provider is not None:
            provider = provider.provider
        return provider

    def _resolve_config_option(
            self,
            original: providers.ConfigurationOption,
            as_: Any = None,
    ) -> Optional[providers.Provider]:
        """Rebuild a configuration option path on the mapped configuration root.

        Provider-valued segments are resolved first. The ``required`` flag and
        the optional ``as_`` conversion are re-applied to the rebuilt option.
        """
        original_root = original.root
        new = self._resolve_provider(original_root)
        if new is None:
            return None
        new = cast(providers.Configuration, new)

        for segment in original.get_name_segments():
            if providers.is_provider(segment):
                segment = self.resolve_provider(segment)
                new = new[segment]
            else:
                new = getattr(new, segment)

        if original.is_required():
            new = new.required()

        if as_:
            new = new.as_(as_)

        return new

    def _resolve_provider(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Look ``original`` up in the providers map; ``None`` if it is not mapped."""
        return self._map.get(original)

    @classmethod
    def _create_providers_map(
            cls,
            current_container: Container,
            original_container: Container,
    ) -> Dict[providers.Provider, providers.Provider]:
        """Map each original provider to the same-named provider of the current container.

        Nested ``providers.Container`` providers are mapped recursively.
        """
        # NOTE(review): "__self__" is written into the mapping returned by
        # ``.providers``; whether that is the container's own dict (i.e. a
        # persistent mutation) is not visible here — TODO confirm.
        current_providers = current_container.providers
        current_providers["__self__"] = current_container.__self__

        original_providers = original_container.providers
        original_providers["__self__"] = original_container.__self__

        providers_map = {}
        for provider_name, current_provider in current_providers.items():
            original_provider = original_providers[provider_name]
            providers_map[original_provider] = current_provider

            if isinstance(current_provider, providers.Container) \
                    and isinstance(original_provider, providers.Container):
                subcontainer_map = cls._create_providers_map(
                    current_container=current_provider.container,
                    original_container=original_provider.container,
                )
                providers_map.update(subcontainer_map)

        return providers_map
361

362

363
class InspectFilter:
    """Decide which module members wiring must not introspect."""

    def is_excluded(self, instance: object) -> bool:
        """Return True for members that wiring should skip entirely."""
        return (
            self._is_werkzeug_local_proxy(instance)
            or self._is_starlette_request_cls(instance)
            or self._is_builtin(instance)
        )

    def _is_werkzeug_local_proxy(self, instance: object) -> bool:
        # ``werkzeug`` is None when the optional dependency is not installed;
        # return a real bool instead of leaking that None to callers.
        return werkzeug is not None and isinstance(instance, werkzeug.local.LocalProxy)

    def _is_starlette_request_cls(self, instance: object) -> bool:
        # Same optional-dependency handling as above, for starlette.
        return (
            starlette is not None
            and isinstance(instance, type)
            and _safe_is_subclass(instance, starlette.requests.Request)
        )

    def _is_builtin(self, instance: object) -> bool:
        return inspect.isbuiltin(instance)
385

386

387
def wire(  # noqa: C901
        container: Container,
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Wire container providers with provided packages and modules.

    Every module in ``modules``, plus every module found under ``packages``, is
    scanned. Marker attributes are replaced, and functions and methods that
    declare markers are wrapped. Afterwards every wrapper registered for the
    module is bound against ``container``'s providers.
    """
    targets = list(modules) if modules else []
    for package in packages or ():
        targets.extend(_fetch_modules(package))

    providers_map = ProvidersMap(container)

    for module in targets:
        for member_name, member in inspect.getmembers(module):
            if _inspect_filter.is_excluded(member):
                continue

            if _is_marker(member):
                _patch_attribute(module, member_name, member, providers_map)
            elif inspect.isfunction(member):
                _patch_fn(module, member_name, member, providers_map)
            elif inspect.isclass(member):
                _wire_class_members(member, providers_map)

        for patched in _patched_registry.get_callables_from_module(module):
            _bind_injections(patched, providers_map)


def _wire_class_members(cls: Type, providers_map: "ProvidersMap") -> None:
    """Patch marker attributes and marker-using methods defined on ``cls``."""
    try:
        members = inspect.getmembers(cls)
    except Exception:  # noqa
        # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/441
        return

    for attr_name, attr in members:
        if _is_marker(attr):
            _patch_attribute(cls, attr_name, attr, providers_map)
        elif _is_method(attr):
            _patch_method(cls, attr_name, attr, providers_map)
427

428

429
def unwire(  # noqa: C901
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Unwire provided packages and modules, undoing what ``wire`` did to them.

    Bound injections of every wrapper in the affected modules are dropped, and
    replaced marker attributes are restored and forgotten.
    """
    modules = [*modules] if modules else []

    if packages:
        for package in packages:
            modules.extend(_fetch_modules(package))

    for module in modules:
        for name, member in inspect.getmembers(module):
            # ``wire`` never introspects members rejected by the inspect filter
            # (and never patches them), so ``unwire`` must skip them too.
            if _inspect_filter.is_excluded(member):
                continue

            if inspect.isfunction(member):
                _unpatch(module, name, member)
            elif inspect.isclass(member):
                try:
                    methods = inspect.getmembers(member, inspect.isfunction)
                except Exception:  # noqa
                    # Same hotfix as in ``wire``, see:
                    # https://github.com/ets-labs/python-dependency-injector/issues/441
                    continue
                for method_name, method in methods:
                    _unpatch(member, method_name, method)

        for patched in _patched_registry.get_callables_from_module(module):
            _unbind_injections(patched)

        for patched_attribute in _patched_registry.get_attributes_from_module(module):
            _unpatch_attribute(patched_attribute)
        _patched_registry.clear_module_attributes(module)
455

456

457
def inject(fn: F) -> F:
    """Decorate ``fn`` so its marker-declared arguments can be injected.

    Markers are read from ``fn``'s signature immediately; the actual providers
    are bound later, when a container wires the module.
    """
    injections, closing = _fetch_reference_injections(fn)
    return cast(F, _get_patched(fn, injections, closing))
462

463

464
def _patch_fn(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Make ``module.<name>`` an injecting wrapper bound to ``providers_map``.

    Functions that are not wrapped yet and declare no markers are left alone.
    """
    if _is_patched(fn):
        wrapper = fn
    else:
        injections, closing = _fetch_reference_injections(fn)
        if not injections:
            return
        wrapper = _get_patched(fn, injections, closing)

    _bind_injections(wrapper, providers_map)
    setattr(module, name, wrapper)
479

480

481
def _patch_method(
        cls: Type,
        name: str,
        method: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Replace method ``name`` of ``cls`` with an injecting wrapper if it uses markers.

    ``classmethod`` / ``staticmethod`` descriptors found in ``cls.__dict__`` are
    unwrapped before patching, then re-wrapped in the same descriptor type.
    """
    namespace = cls.__dict__ if hasattr(cls, "__dict__") else {}
    descriptor = namespace[name] if name in namespace else None
    if isinstance(descriptor, (classmethod, staticmethod)):
        method = descriptor
        target = descriptor.__func__
    else:
        target = method

    if not _is_patched(target):
        injections, closing = _fetch_reference_injections(target)
        if not injections:
            return
        target = _get_patched(target, injections, closing)

    _bind_injections(target, providers_map)

    if isinstance(method, (classmethod, staticmethod)):
        target = type(method)(target)

    setattr(cls, name, target)
507

508

509
def _unpatch(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
) -> None:
    """Drop the bound injections of ``fn`` when it is a registered wrapper.

    ``module`` may also be a class. For class/static methods, the function is
    taken from the descriptor stored in its ``__dict__``.
    """
    attrs = module.__dict__ if hasattr(module, "__dict__") else {}
    if name in attrs and isinstance(attrs[name], (classmethod, staticmethod)):
        fn = attrs[name].__func__

    if _is_patched(fn):
        _unbind_injections(fn)
524

525

526
def _patch_attribute(
        member: Any,
        name: str,
        marker: "_Marker",
        providers_map: ProvidersMap,
) -> None:
    """Replace marker attribute ``member.<name>`` with what the marker requests.

    ``Provide`` markers become ``provider()``, and ``Provider`` markers become
    the provider itself. Markers that do not resolve are left untouched. Any
    other marker type raises ``Exception``, after the attribute has been
    registered.
    """
    provider = providers_map.resolve_provider(marker.provider, marker.modifier)
    if provider is None:
        return

    _patched_registry.register_attribute(PatchedAttribute(member, name, marker))

    if isinstance(marker, Provide):
        value = provider()
    elif isinstance(marker, Provider):
        value = provider
    else:
        raise Exception(f"Unknown type of marker {marker}")
    setattr(member, name, value)
545

546

547
def _unpatch_attribute(patched: PatchedAttribute) -> None:
    """Put the original marker back onto the patched module or class."""
    owner, attr_name, marker = patched.member, patched.name, patched.marker
    setattr(owner, attr_name, marker)
549

550

551
def _fetch_reference_injections(  # noqa: C901
        fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Collect the marker defaults declared in ``fn``'s signature.

    Returns ``(injections, closing)``. ``injections`` maps parameter names to
    markers, with FastAPI ``Depends`` wrappers and ``Closing`` wrappers unwrapped.
    ``closing`` holds the subset of parameters that were wrapped in ``Closing``.
    Callables whose signature cannot be read yield two empty dicts.
    """
    # Hotfix, see:
    # - https://github.com/ets-labs/python-dependency-injector/issues/362
    # - https://github.com/ets-labs/python-dependency-injector/issues/398
    if GenericAlias:
        is_alias = fn is GenericAlias
        wraps_alias = getattr(fn, "__func__", None) is GenericAlias
        if is_alias or wraps_alias:
            fn = fn.__init__

    try:
        signature = inspect.signature(fn)
    except ValueError as exception:
        message = str(exception)
        if "no signature found" in message or "not supported by signature" in message:
            return {}, {}
        raise exception

    injections = {}
    closing = {}
    for parameter_name, parameter in signature.parameters.items():
        marker = parameter.default
        if _is_fastapi_depends(marker):
            # FastAPI ``Depends(<marker>)``: look at the wrapped dependency.
            marker = marker.dependency
        if not isinstance(marker, _Marker):
            continue

        if isinstance(marker, Closing):
            marker = marker.provider
            closing[parameter_name] = marker

        injections[parameter_name] = marker
    return injections, closing
594

595

596
def _locate_dependent_closing_args(provider: providers.Provider) -> Dict[str, providers.Provider]:
    """Collect ``Resource`` providers reachable through ``provider``'s positional args.

    Argument-less ``Resource`` providers are collected. Other provider arguments
    that have ``args`` of their own are searched recursively. Results are keyed
    by ``str(id(resource))``, so a resource reachable twice is stored once.
    """
    if not hasattr(provider, "args"):
        return {}

    closing_deps: Dict[str, providers.Provider] = {}
    for arg in provider.args:
        if not isinstance(arg, providers.Provider) or not hasattr(arg, "args"):
            continue

        if not arg.args and isinstance(arg, providers.Resource):
            # Keep scanning: resources found earlier and later arguments must
            # not be dropped.
            closing_deps[str(id(arg))] = arg
        else:
            # ``dict += dict`` raises TypeError; merge explicitly.
            closing_deps.update(_locate_dependent_closing_args(arg))
    return closing_deps
610

611

612
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
    """Resolve the reference markers of wrapper ``fn`` against ``providers_map``.

    Unregistered callables, and markers whose provider does not resolve, are
    skipped. Parameters declared with ``Closing`` also get closing entries: the
    provider itself plus the resources found among its arguments.
    """
    patched_callable = _patched_registry.get_callable(fn)
    if patched_callable is None:
        return

    closing_names = patched_callable.reference_closing
    for injection, marker in patched_callable.reference_injections.items():
        provider = providers_map.resolve_provider(marker.provider, marker.modifier)
        if provider is None:
            continue

        if isinstance(marker, Provide):
            patched_callable.add_injection(injection, provider)
        elif isinstance(marker, Provider):
            # A resolved Delegate is injected as is; any other provider is
            # injected through its ``.provider`` attribute.
            value = provider if isinstance(provider, providers.Delegate) else provider.provider
            patched_callable.add_injection(injection, value)

        if injection in closing_names:
            patched_callable.add_closing(injection, provider)
            for key, dep in _locate_dependent_closing_args(provider).items():
                patched_callable.add_closing(key, dep)
636

637

638
def _unbind_injections(fn: Callable[..., Any]) -> None:
    """Clear the bound injections of wrapper ``fn``; no-op for unknown callables."""
    patched_callable = _patched_registry.get_callable(fn)
    if patched_callable is not None:
        patched_callable.unwind_injections()
643

644

645
def _fetch_modules(package):
    """Return ``package`` followed by every module found beneath it.

    Objects lacking ``__path__`` or ``__name__`` (e.g. plain modules) yield only
    themselves. Every discovered submodule is imported.
    """
    if not (hasattr(package, "__path__") and hasattr(package, "__name__")):
        return [package]

    found = [package]
    walker = pkgutil.walk_packages(
        path=package.__path__,
        prefix=package.__name__ + ".",
    )
    for info in walker:
        found.append(importlib.import_module(info.name))
    return found
656

657

658
def _is_method(member) -> bool:
    """Return True for plain functions and bound methods."""
    return any(check(member) for check in (inspect.ismethod, inspect.isfunction))
660

661

662
def _is_marker(member) -> bool:
    """Tell whether ``member`` is an instance of a wiring marker class."""
    is_marker_instance = isinstance(member, _Marker)
    return is_marker_instance
664

665

666
def _get_patched(
        fn: F,
        reference_injections: Dict[Any, Any],
        reference_closing: Dict[Any, Any],
) -> F:
    """Wrap ``fn`` in an injecting wrapper and register the wrapper.

    Coroutine functions get the async wrapper; everything else gets the sync one.
    """
    record = PatchedCallable(
        original=fn,
        reference_injections=reference_injections,
        reference_closing=reference_closing,
    )

    make_wrapper = _get_async_patched if inspect.iscoroutinefunction(fn) else _get_sync_patched
    patched = make_wrapper(fn, record)

    record.patched = patched
    _patched_registry.register_callable(record)
    return patched
686

687

688
def _is_fastapi_depends(param: Any) -> bool:
    """Return True if ``param`` is a FastAPI ``Depends`` instance.

    ``fastapi`` is None when FastAPI is not installed; return False then rather
    than leaking that None to callers.
    """
    return fastapi is not None and isinstance(param, fastapi.params.Depends)
690

691

692
def _is_patched(fn) -> bool:
    """Tell whether ``fn`` is a wrapper registered by ``_get_patched``."""
    registry = _patched_registry
    return registry.has_callable(fn)
694

695

696
def _is_declarative_container(instance: Any) -> bool:
    """Tell whether ``instance`` is a declarative container class with no declarative parent."""
    if not isinstance(instance, type):
        return False
    if getattr(instance, "__IS_CONTAINER__", False) is not True:
        return False
    return getattr(instance, "declarative_parent", None) is None
700

701

702
def _safe_is_subclass(instance: Any, cls: Type) -> bool:
    """``issubclass`` that answers False instead of raising ``TypeError``."""
    try:
        result = issubclass(instance, cls)
    except TypeError:
        result = False
    return result
707

708

709
class Modifier:
    """Base class for marker modifiers.

    A modifier transforms the provider that a marker's string id resolved to
    (see ``ProvidersMap._resolve_string_id``).
    """

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return the modified provider; this base implementation returns ``None``."""
717

718

719
class TypeModifier(Modifier):
    """Modifier converting a configuration option with ``provider.as_(type_)``."""

    def __init__(self, type_: Type) -> None:
        self.type_ = type_

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return ``provider`` converted to ``self.type_``; ``providers_map`` is unused."""
        target_type = self.type_
        return provider.as_(target_type)


def as_int() -> TypeModifier:
    """Return a modifier that converts the option to ``int``."""
    return as_(int)


def as_float() -> TypeModifier:
    """Return a modifier that converts the option to ``float``."""
    return as_(float)


def as_(type_: Type) -> TypeModifier:
    """Return a modifier that converts the option to ``type_``."""
    return TypeModifier(type_)
745

746

747
class RequiredModifier(Modifier):
    """Modifier making a configuration option required, optionally with a type.

    The type conversion is chained fluently, e.g. ``required().as_int()``.
    """

    def __init__(self) -> None:
        self.type_modifier = None

    def as_int(self) -> "RequiredModifier":
        """Also convert the option to ``int``; returns ``self``."""
        self.type_modifier = TypeModifier(int)
        return self

    def as_float(self) -> "RequiredModifier":
        """Also convert the option to ``float``; returns ``self``."""
        self.type_modifier = TypeModifier(float)
        return self

    def as_(self, type_: Type) -> "RequiredModifier":
        """Also convert the option to ``type_``; returns ``self``."""
        self.type_modifier = TypeModifier(type_)
        return self

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return ``provider.required()``, converted if a type was chosen."""
        required_option = provider.required()
        if not self.type_modifier:
            return required_option
        return required_option.as_(self.type_modifier.type_)


def required() -> RequiredModifier:
    """Return a fresh "required" modifier."""
    return RequiredModifier()
778

779

780
class InvariantModifier(Modifier):
    """Modifier selecting ``provider[<provider that id resolves to>]``."""

    def __init__(self, id: str) -> None:
        self.id = id

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Index ``provider`` with the provider ``self.id`` resolves to."""
        key = providers_map.resolve_provider(self.id)
        return provider[key]


def invariant(id: str) -> InvariantModifier:
    """Return an invariant modifier keyed by the string id ``id``."""
    return InvariantModifier(id)
797

798

799
class ProvidedInstance(Modifier):
    """Modifier recording accesses to replay on ``provider.provided``.

    Attribute access, item access and ``call()`` each append a segment and return
    ``self``, so chains like ``provided().service["key"].call()`` can be built.
    """

    TYPE_ATTRIBUTE = "attr"
    TYPE_ITEM = "item"
    TYPE_CALL = "call"

    def __init__(self) -> None:
        self.segments = []

    def __getattr__(self, item):
        # Dunder lookups are protocol probes (e.g. ``copy.deepcopy`` looking
        # for ``__deepcopy__``, ``hasattr(obj, "__setstate__")``), not user
        # chains. Recording them would silently mutate the modifier and make
        # the probe receive ``self`` instead of an AttributeError.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        self.segments.append((self.TYPE_ATTRIBUTE, item))
        return self

    def __getitem__(self, item):
        self.segments.append((self.TYPE_ITEM, item))
        return self

    def call(self):
        self.segments.append((self.TYPE_CALL, None))
        return self

    def modify(
            self,
            provider: providers.Provider,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Replay the recorded segments, in order, on ``provider.provided``."""
        provider = provider.provided
        for type_, value in self.segments:
            if type_ == ProvidedInstance.TYPE_ATTRIBUTE:
                provider = getattr(provider, value)
            elif type_ == ProvidedInstance.TYPE_ITEM:
                provider = provider[value]
            elif type_ == ProvidedInstance.TYPE_CALL:
                provider = provider.call()
        return provider


def provided() -> ProvidedInstance:
    """Return provided instance modifier."""
    return ProvidedInstance()
839

840

841
class ClassGetItemMeta(GenericMeta):
    """Metaclass turning ``Marker[...]`` into ``Marker(...)``.

    A tuple subscript is spread into positional constructor arguments.
    """

    def __getitem__(cls, item):
        # Spike for Python 3.6
        args = item if isinstance(item, tuple) else (item,)
        return cls(*args)
847

848

849
class _Marker(Generic[T], metaclass=ClassGetItemMeta):
    """Common base of the ``Provide`` / ``Provider`` / ``Closing`` markers.

    A marker stores a provider reference (provider, container or string id)
    plus an optional modifier. Subscripting the class builds an instance, and
    calling an instance returns the instance itself.
    """

    __IS_MARKER__ = True

    def __init__(
            self,
            provider: Union[providers.Provider, Container, str],
            modifier: Optional[Modifier] = None,
    ) -> None:
        # A declarative container class is referenced through its ``__self__``.
        self.provider = provider.__self__ if _is_declarative_container(provider) else provider
        self.modifier = modifier

    def __class_getitem__(cls, item) -> T:
        args = item if isinstance(item, tuple) else (item,)
        return cls(*args)

    def __call__(self) -> T:
        return self


class Provide(_Marker):
    """Marker resolved to the provider's output (attributes are set to ``provider()``)."""


class Provider(_Marker):
    """Marker resolved to the provider object itself rather than its output."""


class Closing(_Marker):
    """Wrapper marker: the wrapped marker is injected and also recorded for closing."""
882

883

884
class AutoLoader:
    """Auto-wiring module loader.

    While installed, a path hook's source and bytecode loaders wire every
    registered container into each module right after the module executes.
    """

    def __init__(self) -> None:
        self.containers = []
        self._path_hook = None

    def register_containers(self, *containers) -> None:
        """Add ``containers`` and install the import hook if it is not installed."""
        self.containers.extend(containers)
        if not self.installed:
            self.install()

    def unregister_containers(self, *containers) -> None:
        """Remove ``containers``; uninstall the hook once none remain."""
        for container in containers:
            self.containers.remove(container)
        if not self.containers:
            self.uninstall()

    def wire_module(self, module) -> None:
        """Wire every registered container into ``module``."""
        for container in self.containers:
            container.wire(modules=[module])

    @property
    def installed(self) -> bool:
        """Whether this loader's path hook is currently in ``sys.path_hooks``."""
        return self._path_hook in sys.path_hooks

    def install(self) -> None:
        """Put the wiring path hook at the front of ``sys.path_hooks``."""
        if self.installed:
            return

        autoloader = self

        class SourcelessFileLoader(importlib.machinery.SourcelessFileLoader):
            def exec_module(self, module):
                super().exec_module(module)
                autoloader.wire_module(module)

        class SourceFileLoader(importlib.machinery.SourceFileLoader):
            def exec_module(self, module):
                super().exec_module(module)
                autoloader.wire_module(module)

        class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader):
            # Extension modules are loaded without wiring.
            ...

        self._path_hook = importlib.machinery.FileFinder.path_hook(
            (SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
            (SourceFileLoader, importlib.machinery.SOURCE_SUFFIXES),
            (ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES),
        )
        sys.path_hooks.insert(0, self._path_hook)
        self._reset_import_caches()

    def uninstall(self) -> None:
        """Remove the wiring path hook if it is installed."""
        if not self.installed:
            return

        sys.path_hooks.remove(self._path_hook)
        self._reset_import_caches()

    @staticmethod
    def _reset_import_caches() -> None:
        # Finders cached per path entry must be rebuilt for a hook change to apply.
        sys.path_importer_cache.clear()
        importlib.invalidate_caches()
953

954

955
def register_loader_containers(*containers: Container) -> None:
    """Add containers to the module-wide auto-wiring loader (installing its hook if needed)."""
    loader = _loader
    loader.register_containers(*containers)


def unregister_loader_containers(*containers: Container) -> None:
    """Remove containers from the module-wide auto-wiring loader."""
    loader = _loader
    loader.unregister_containers(*containers)


def install_loader() -> None:
    """Install the module-wide auto-wiring import hook."""
    loader = _loader
    loader.install()


def uninstall_loader() -> None:
    """Remove the module-wide auto-wiring import hook."""
    loader = _loader
    loader.uninstall()


def is_loader_installed() -> bool:
    """Report whether the module-wide auto-wiring import hook is installed."""
    loader = _loader
    return loader.installed
978

979

980
# Module-wide singletons shared by the wiring functions of this module.
_loader = AutoLoader()  # auto-wiring import hook state
_inspect_filter = InspectFilter()  # members wiring must not introspect
_patched_registry = PatchedRegistry()  # everything wiring has patched
983

984
# Optimizations
985
from ._cwiring import _get_sync_patched  # noqa
×
986
from ._cwiring import _async_inject  # noqa
×
987

988

989
# Wiring uses the following Python wrapper because there is
990
# no possibility to compile a first-type citizen coroutine in Cython.
991
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
    """Return an ``async`` wrapper that injects into coroutine function ``fn``.

    ``patched.injections`` / ``patched.closing`` are read on every call, so
    binding or unwinding after the wrapper is created is picked up.
    """
    @functools.wraps(fn)
    async def _patched(*args, **kwargs):
        injections = patched.injections
        closing = patched.closing
        return await _async_inject(fn, args, kwargs, injections, closing)

    return _patched
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc