• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ets-labs / python-dependency-injector / 3726523887

pending completion
3726523887

Pull #647

github-actions

GitHub
Merge fad248b4c into 3858cef65
Pull Request #647: Python 3.11 Support

691 of 750 relevant lines covered (92.13%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.71
/src/dependency_injector/wiring.py
1
"""Wiring module."""
2

3
import functools
1✔
4
import inspect
1✔
5
import importlib
1✔
6
import importlib.machinery
1✔
7
import pkgutil
1✔
8
import warnings
1✔
9
import sys
1✔
10
from types import ModuleType
1✔
11
from typing import (
1✔
12
    Optional,
13
    Iterable,
14
    Iterator,
15
    Callable,
16
    Any,
17
    Tuple,
18
    Dict,
19
    Generic,
20
    TypeVar,
21
    Type,
22
    Union,
23
    Set,
24
    cast,
25
)
26

27
if sys.version_info < (3, 7):
1✔
28
    from typing import GenericMeta
×
29
else:
30
    class GenericMeta(type):
1✔
31
        ...
1✔
32

33
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
34
if sys.version_info >= (3, 9):
1✔
35
    from types import GenericAlias
1✔
36
else:
37
    GenericAlias = None
×
38

39

40
try:
1✔
41
    import fastapi.params
1✔
42
except ImportError:
×
43
    fastapi = None
×
44

45

46
try:
1✔
47
    import starlette.requests
1✔
48
except ImportError:
×
49
    starlette = None
×
50

51

52
try:
1✔
53
    import werkzeug.local
1✔
54
except ImportError:
×
55
    werkzeug = None
×
56

57

58
from . import providers
1✔
59

60
if sys.version_info[:2] == (3, 5):
1✔
61
    warnings.warn(
×
62
        "Dependency Injector will drop support of Python 3.5 after Jan 1st of 2022. "
63
        "This does not mean that there will be any immediate breaking changes, "
64
        "but tests will no longer be executed on Python 3.5, and bugs will not be addressed.",
65
        category=DeprecationWarning,
66
    )
67

68
__all__ = (
1✔
69
    "wire",
70
    "unwire",
71
    "inject",
72
    "as_int",
73
    "as_float",
74
    "as_",
75
    "required",
76
    "invariant",
77
    "provided",
78
    "Provide",
79
    "Provider",
80
    "Closing",
81
    "register_loader_containers",
82
    "unregister_loader_containers",
83
    "install_loader",
84
    "uninstall_loader",
85
    "is_loader_installed",
86
)
87

88
T = TypeVar("T")
1✔
89
F = TypeVar("F", bound=Callable[..., Any])
1✔
90
Container = Any
1✔
91

92

93
class PatchedRegistry:
1✔
94

95
    def __init__(self) -> None:
1✔
96
        self._callables: Dict[Callable[..., Any], "PatchedCallable"] = {}
1✔
97
        self._attributes: Set[PatchedAttribute] = set()
1✔
98

99
    def register_callable(self, patched: "PatchedCallable") -> None:
1✔
100
        self._callables[patched.patched] = patched
1✔
101

102
    def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
1✔
103
        for patched_callable in self._callables.values():
1✔
104
            if not patched_callable.is_in_module(module):
1✔
105
                continue
1✔
106
            yield patched_callable.patched
1✔
107

108
    def get_callable(self, fn: Callable[..., Any]) -> "PatchedCallable":
1✔
109
        return self._callables.get(fn)
1✔
110

111
    def has_callable(self, fn: Callable[..., Any]) -> bool:
1✔
112
        return fn in self._callables
1✔
113

114
    def register_attribute(self, patched: "PatchedAttribute") -> None:
1✔
115
        self._attributes.add(patched)
1✔
116

117
    def get_attributes_from_module(self, module: ModuleType) -> Iterator["PatchedAttribute"]:
1✔
118
        for attribute in self._attributes:
1✔
119
            if not attribute.is_in_module(module):
1✔
120
                continue
1✔
121
            yield attribute
1✔
122

123
    def clear_module_attributes(self, module: ModuleType) -> None:
1✔
124
        for attribute in self._attributes.copy():
1✔
125
            if not attribute.is_in_module(module):
1✔
126
                continue
1✔
127
            self._attributes.remove(attribute)
1✔
128

129

130
class PatchedCallable:
    """Record describing one patched callable.

    ``reference_injections`` / ``reference_closing`` hold copies of the
    dictionaries given at construction time, while ``injections`` /
    ``closing`` start empty and are filled through ``add_injection`` /
    ``add_closing``.
    """

    __slots__ = (
        "patched",
        "original",
        "reference_injections",
        "injections",
        "reference_closing",
        "closing",
    )

    def __init__(
            self,
            patched: Optional[Callable[..., Any]] = None,
            original: Optional[Callable[..., Any]] = None,
            reference_injections: Optional[Dict[Any, Any]] = None,
            reference_closing: Optional[Dict[Any, Any]] = None,
    ) -> None:
        self.patched = patched
        self.original = original

        # Copies are stored, so later changes to the caller's dictionaries
        # are not reflected in this record.
        self.reference_injections: Dict[Any, Any] = (
            {} if reference_injections is None else reference_injections.copy()
        )
        self.reference_closing: Dict[Any, Any] = (
            {} if reference_closing is None else reference_closing.copy()
        )

        self.injections: Dict[Any, Any] = {}
        self.closing: Dict[Any, Any] = {}

    def is_in_module(self, module: ModuleType) -> bool:
        """Tell whether the patched callable's ``__module__`` equals ``module.__name__``.

        A record without a patched callable belongs to no module.
        """
        return self.patched is not None and self.patched.__module__ == module.__name__

    def add_injection(self, kwarg: Any, injection: Any) -> None:
        """Set the injection stored for ``kwarg``."""
        self.injections[kwarg] = injection

    def add_closing(self, kwarg: Any, injection: Any) -> None:
        """Set the closing entry stored for ``kwarg``."""
        self.closing[kwarg] = injection

    def unwind_injections(self) -> None:
        """Replace the injections and closing entries with fresh empty dicts."""
        self.injections, self.closing = {}, {}
175

176

177
class PatchedAttribute:
    """Record of one patched attribute: its owner, its name and its marker."""

    __slots__ = (
        "member",
        "name",
        "marker",
    )

    def __init__(self, member: Any, name: str, marker: "_Marker") -> None:
        self.member = member
        self.name = name
        self.marker = marker

    @property
    def module_name(self) -> str:
        """Name of the module the owner lives in.

        A module owner reports its own ``__name__``; any other owner
        reports its ``__module__``.
        """
        owner = self.member
        return owner.__name__ if isinstance(owner, ModuleType) else owner.__module__

    def is_in_module(self, module: ModuleType) -> bool:
        """Tell whether ``module_name`` equals ``module.__name__``."""
        return module.__name__ == self.module_name
199

200

201
class ProvidersMap:
    """Resolve providers referenced by markers against a given container.

    On construction a mapping is built from the providers of the
    "original" container (the container's ``declarative_parent`` when it
    has one, otherwise the container itself) to the same-named providers
    of the given container.
    """

    CONTAINER_STRING_ID = "<container>"

    def __init__(self, container) -> None:
        declarative_parent = container.declarative_parent
        self._container = container
        self._map = self._create_providers_map(
            current_container=container,
            original_container=declarative_parent if declarative_parent else container,
        )

    def resolve_provider(
            self,
            provider: Union[providers.Provider, str],
            modifier: Optional["Modifier"] = None,
    ) -> Optional[providers.Provider]:
        """Return the counterpart of ``provider`` or ``None`` when unresolved.

        The kind of ``provider`` selects the resolution strategy; the
        checks run in a fixed order and the first match wins.
        ``modifier`` is only used for string identifiers.
        """
        if isinstance(provider, providers.Delegate):
            return self._resolve_delegate(provider)

        provided_instance_types = (
            providers.ProvidedInstance,
            providers.AttributeGetter,
            providers.ItemGetter,
            providers.MethodCaller,
        )
        if isinstance(provider, provided_instance_types):
            return self._resolve_provided_instance(provider)

        if isinstance(provider, providers.ConfigurationOption):
            return self._resolve_config_option(provider)

        if isinstance(provider, providers.TypedConfigurationOption):
            return self._resolve_config_option(provider.option, as_=provider.provides)

        if isinstance(provider, str):
            return self._resolve_string_id(provider, modifier)

        return self._resolve_provider(provider)

    def _resolve_string_id(
            self,
            id: str,
            modifier: Optional["Modifier"] = None,
    ) -> Optional[providers.Provider]:
        """Resolve a dotted string identifier by attribute traversal.

        The special ``CONTAINER_STRING_ID`` value yields the container's
        ``__self__``. A missing attribute anywhere on the path gives
        ``None``. A truthy ``modifier`` is applied to the final provider.
        """
        if id == self.CONTAINER_STRING_ID:
            return self._container.__self__

        segments = id.split(".")
        resolved = self._container
        try:
            for segment in segments:
                resolved = getattr(resolved, segment)
        except AttributeError:
            return None

        if modifier:
            resolved = modifier.modify(resolved, providers_map=self)
        return resolved

    def _resolve_provided_instance(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Resolve a provided-instance chain and rebuild it on the new root."""
        chain_types = (
            providers.ProvidedInstance,
            providers.AttributeGetter,
            providers.ItemGetter,
            providers.MethodCaller,
        )

        # Walk from the outermost link down to the root provider, then
        # replay the links root-first.
        links = []
        while isinstance(original, chain_types):
            links.append(original)
            original = original.provides
        links.reverse()

        rebuilt = self._resolve_provider(original)
        if rebuilt is None:
            return None

        for link in links:
            if isinstance(link, providers.ProvidedInstance):
                rebuilt = rebuilt.provided
            elif isinstance(link, providers.AttributeGetter):
                rebuilt = getattr(rebuilt, link.name)
            elif isinstance(link, providers.ItemGetter):
                rebuilt = rebuilt[link.name]
            elif isinstance(link, providers.MethodCaller):
                rebuilt = rebuilt.call(*link.args, **link.kwargs)

        return rebuilt

    def _resolve_delegate(
            self,
            original: providers.Delegate,
    ) -> Optional[providers.Provider]:
        """Resolve the delegated provider and return its ``provider`` attribute.

        A falsy resolution result is returned unchanged.
        """
        resolved = self._resolve_provider(original.provides)
        return resolved.provider if resolved else resolved

    def _resolve_config_option(
            self,
            original: providers.ConfigurationOption,
            as_: Any = None,
    ) -> Optional[providers.Provider]:
        """Resolve a configuration option by replaying its name segments.

        Returns ``None`` when the option's root cannot be resolved.
        Segments that are providers are resolved themselves and used as
        item keys; other segments are used as attribute names.
        """
        option = self._resolve_provider(original.root)
        if option is None:
            return None
        option = cast(providers.Configuration, option)

        for segment in original.get_name_segments():
            if providers.is_provider(segment):
                option = option[self.resolve_provider(segment)]
            else:
                option = getattr(option, segment)

        if original.is_required():
            option = option.required()

        if as_:
            option = option.as_(as_)

        return option

    def _resolve_provider(
            self,
            original: providers.Provider,
    ) -> Optional[providers.Provider]:
        """Look ``original`` up in the mapping; ``None`` when it is absent."""
        return self._map.get(original)

    @classmethod
    def _create_providers_map(
            cls,
            current_container: Container,
            original_container: Container,
    ) -> Dict[providers.Provider, providers.Provider]:
        """Pair same-named providers of both containers, original to current.

        Each container's ``__self__`` is included under the ``"__self__"``
        key. When both paired providers are ``providers.Container``
        instances, their inner containers are mapped recursively.
        """
        current_by_name = current_container.providers
        current_by_name["__self__"] = current_container.__self__

        original_by_name = original_container.providers
        original_by_name["__self__"] = original_container.__self__

        mapping = {}
        for provider_name, current_provider in current_by_name.items():
            original_provider = original_by_name[provider_name]
            mapping[original_provider] = current_provider

            both_are_containers = (
                isinstance(current_provider, providers.Container)
                and isinstance(original_provider, providers.Container)
            )
            if both_are_containers:
                mapping.update(
                    cls._create_providers_map(
                        current_container=current_provider.container,
                        original_container=original_provider.container,
                    ),
                )

        return mapping
361

362

363
class InspectFilter:
    """Decide which inspected members wiring must leave untouched."""

    def is_excluded(self, instance: object) -> bool:
        """Return ``True`` when ``instance`` should be skipped during wiring.

        A member is skipped when it is a werkzeug ``LocalProxy``, a
        subclass of the starlette ``Request`` class, or a builtin.
        """
        return (
            self._is_werkzeug_local_proxy(instance)
            or self._is_starlette_request_cls(instance)
            or self._is_builtin(instance)
        )

    def _is_werkzeug_local_proxy(self, instance: object) -> bool:
        """Tell whether ``instance`` is a werkzeug ``LocalProxy``.

        Always ``False`` when werkzeug could not be imported.
        """
        # ``werkzeug`` is None when the optional import failed; comparing
        # against None keeps the result a real bool instead of leaking None.
        return werkzeug is not None and isinstance(instance, werkzeug.local.LocalProxy)

    def _is_starlette_request_cls(self, instance: object) -> bool:
        """Tell whether ``instance`` is a class derived from starlette ``Request``.

        Always ``False`` when starlette could not be imported.
        """
        return (
            starlette is not None
            and isinstance(instance, type)
            and _safe_is_subclass(instance, starlette.requests.Request)
        )

    def _is_builtin(self, instance: object) -> bool:
        """Tell whether ``inspect.isbuiltin`` reports ``instance`` as builtin."""
        return inspect.isbuiltin(instance)
385

386

387
def wire(  # noqa: C901
        container: Container,
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Wire container providers with provided packages and modules.

    Every given module, plus every module fetched from the given
    packages, is inspected. Markers found as module or class attributes
    are patched, as are module-level functions and class methods.
    Finally, every registered patched callable of the module gets its
    injections bound against ``container``.
    """
    targets = [*modules] if modules else []
    for package in packages or ():
        targets.extend(_fetch_modules(package))

    providers_map = ProvidersMap(container)

    for module in targets:
        for member_name, member in inspect.getmembers(module):
            if _inspect_filter.is_excluded(member):
                continue

            if _is_marker(member):
                _patch_attribute(module, member_name, member, providers_map)
                continue

            if inspect.isfunction(member):
                _patch_fn(module, member_name, member, providers_map)
                continue

            if not inspect.isclass(member):
                continue

            try:
                cls_members = inspect.getmembers(member)
            except Exception:  # noqa
                # Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/441
                continue

            for cls_member_name, cls_member in cls_members:
                if _is_marker(cls_member):
                    _patch_attribute(member, cls_member_name, cls_member, providers_map)
                elif _is_method(cls_member):
                    _patch_method(member, cls_member_name, cls_member, providers_map)

        for patched in _patched_registry.get_callables_from_module(module):
            _bind_injections(patched, providers_map)
427

428

429
def unwire(  # noqa: C901
        *,
        modules: Optional[Iterable[ModuleType]] = None,
        packages: Optional[Iterable[ModuleType]] = None,
) -> None:
    """Unwire provided packages and modules from previously wired providers.

    For every given module, plus every module fetched from the given
    packages, the injections of patched functions and methods are
    dropped and patched attributes are set back to their markers.
    """
    modules = [*modules] if modules else []

    if packages:
        for package in packages:
            modules.extend(_fetch_modules(package))

    for module in modules:
        for name, member in inspect.getmembers(module):
            if inspect.isfunction(member):
                _unpatch(module, name, member)
            elif inspect.isclass(member):
                try:
                    methods = inspect.getmembers(member, inspect.isfunction)
                except Exception:  # noqa
                    # Same guard as in wire(): a class that cannot be
                    # introspected is skipped instead of aborting the run, see
                    # https://github.com/ets-labs/python-dependency-injector/issues/441
                    continue
                for method_name, method in methods:
                    _unpatch(member, method_name, method)

        for patched in _patched_registry.get_callables_from_module(module):
            _unbind_injections(patched)

        for patched_attribute in _patched_registry.get_attributes_from_module(module):
            _unpatch_attribute(patched_attribute)
        _patched_registry.clear_module_attributes(module)
455

456

457
def inject(fn: F) -> F:
    """Decorate callable with injecting decorator.

    The reference injections are read from the signature of ``fn`` and a
    patched version of ``fn`` is returned.
    """
    injections, closing = _fetch_reference_injections(fn)
    return cast(F, _get_patched(fn, injections, closing))
462

463

464
def _patch_fn(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Patch a module-level function and bind its injections.

    A function that is not patched yet and has no reference injections
    is left alone. Otherwise the patched function is stored back on
    ``module`` under ``name``.
    """
    already_patched = _is_patched(fn)
    if not already_patched:
        injections, closing = _fetch_reference_injections(fn)
        if not injections:
            return
        fn = _get_patched(fn, injections, closing)

    _bind_injections(fn, providers_map)
    setattr(module, name, fn)
479

480

481
def _patch_method(
        cls: Type,
        name: str,
        method: Callable[..., Any],
        providers_map: ProvidersMap,
) -> None:
    """Patch a method of ``cls`` and bind its injections.

    When the class dictionary holds a ``classmethod`` / ``staticmethod``
    under ``name``, the wrapped function is patched and re-wrapped with
    the same descriptor type before being set back on the class.
    """
    raw_member = getattr(cls, "__dict__", {}).get(name)
    if isinstance(raw_member, (classmethod, staticmethod)):
        method = raw_member
        fn = raw_member.__func__
    else:
        fn = method

    if not _is_patched(fn):
        injections, closing = _fetch_reference_injections(fn)
        if not injections:
            return
        fn = _get_patched(fn, injections, closing)

    _bind_injections(fn, providers_map)

    if isinstance(method, (classmethod, staticmethod)):
        descriptor_type = type(method)
        fn = descriptor_type(fn)

    setattr(cls, name, fn)
507

508

509
def _unpatch(
        module: ModuleType,
        name: str,
        fn: Callable[..., Any],
) -> None:
    """Drop the injections bound to a patched function or method.

    ``module`` may be a module or a class. When its dictionary holds a
    ``classmethod`` / ``staticmethod`` under ``name``, the function it
    wraps is examined instead of ``fn``. Unpatched callables are ignored.
    """
    raw_member = getattr(module, "__dict__", {}).get(name)
    if isinstance(raw_member, (classmethod, staticmethod)):
        fn = raw_member.__func__

    if _is_patched(fn):
        _unbind_injections(fn)
524

525

526
def _patch_attribute(
        member: Any,
        name: str,
        marker: "_Marker",
        providers_map: ProvidersMap,
) -> None:
    """Replace a marker attribute on ``member`` with what it refers to.

    Nothing happens when the marker's provider cannot be resolved.
    Otherwise the attribute is registered as patched and then set: a
    ``Provide`` marker stores the result of calling the provider, a
    ``Provider`` marker stores the provider itself.

    Raises:
        Exception: for a marker that is neither ``Provide`` nor ``Provider``.
    """
    provider = providers_map.resolve_provider(marker.provider, marker.modifier)
    if provider is None:
        return

    record = PatchedAttribute(member, name, marker)
    _patched_registry.register_attribute(record)

    if isinstance(marker, Provide):
        setattr(member, name, provider())
        return

    if isinstance(marker, Provider):
        setattr(member, name, provider)
        return

    raise Exception(f"Unknown type of marker {marker}")
545

546

547
def _unpatch_attribute(patched: PatchedAttribute) -> None:
    """Set the patched attribute back to the marker recorded for it."""
    owner, attribute_name, marker = patched.member, patched.name, patched.marker
    setattr(owner, attribute_name, marker)
549

550

551
def _fetch_reference_injections(  # noqa: C901
        fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Collect the markers used as parameter defaults in the signature of ``fn``.

    Returns a pair ``(injections, closing)`` keyed by parameter name. A
    default that is a FastAPI ``Depends`` is unwrapped to its
    ``dependency`` first. A ``Closing`` marker is replaced by the marker
    it wraps and is recorded in both dictionaries. Callables for which
    ``inspect.signature`` reports no usable signature yield two empty
    dictionaries.
    """
    # Hotfix, see:
    # - https://github.com/ets-labs/python-dependency-injector/issues/362
    # - https://github.com/ets-labs/python-dependency-injector/issues/398
    if GenericAlias and (
            fn is GenericAlias
            or getattr(fn, "__func__", None) is GenericAlias
    ):
        fn = fn.__init__

    try:
        signature = inspect.signature(fn)
    except ValueError as exception:
        message = str(exception)
        if "no signature found" in message or "not supported by signature" in message:
            return {}, {}
        raise exception

    injections = {}
    closing = {}
    for parameter_name, parameter in signature.parameters.items():
        marker = parameter.default

        if _is_fastapi_depends(marker):
            marker = marker.dependency

        if not isinstance(marker, _Marker):
            continue

        if isinstance(marker, Closing):
            marker = marker.provider
            closing[parameter_name] = marker

        injections[parameter_name] = marker

    return injections, closing
594

595

596
def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
    """Resolve the reference injections of ``fn`` and store the results.

    Callables unknown to the registry and markers whose provider cannot
    be resolved are skipped. For a ``Provider`` marker the resolved
    provider's ``provider`` attribute is injected, unless the resolved
    provider is a ``Delegate``, which is injected as is.
    """
    record = _patched_registry.get_callable(fn)
    if record is None:
        return

    for injection, marker in record.reference_injections.items():
        provider = providers_map.resolve_provider(marker.provider, marker.modifier)
        if provider is None:
            continue

        if isinstance(marker, Provide):
            record.add_injection(injection, provider)
        elif isinstance(marker, Provider):
            injected = (
                provider
                if isinstance(provider, providers.Delegate)
                else provider.provider
            )
            record.add_injection(injection, injected)

        if injection in record.reference_closing:
            record.add_closing(injection, provider)
617

618

619
def _unbind_injections(fn: Callable[..., Any]) -> None:
    """Clear the bound injections of ``fn`` if the registry knows it."""
    record = _patched_registry.get_callable(fn)
    if record is not None:
        record.unwind_injections()
624

625

626
def _fetch_modules(package):
1✔
627
    modules = [package]
1✔
628
    if not hasattr(package, "__path__") or not hasattr(package, "__name__"):
1✔
629
        return modules
1✔
630
    for module_info in pkgutil.walk_packages(
1✔
631
            path=package.__path__,
632
            prefix=package.__name__ + ".",
633
    ):
634
        module = importlib.import_module(module_info.name)
1✔
635
        modules.append(module)
1✔
636
    return modules
1✔
637

638

639
def _is_method(member) -> bool:
1✔
640
    return inspect.ismethod(member) or inspect.isfunction(member)
1✔
641

642

643
def _is_marker(member) -> bool:
    """Tell whether ``member`` is an instance of ``_Marker``."""
    is_marker = isinstance(member, _Marker)
    return is_marker
645

646

647
def _get_patched(
        fn: F,
        reference_injections: Dict[Any, Any],
        reference_closing: Dict[Any, Any],
) -> F:
    """Build the patched version of ``fn`` and register it.

    Coroutine functions get the async wrapper, everything else the sync
    one. The wrapper is stored on the ``PatchedCallable`` record before
    the record is registered, and is then returned.
    """
    record = PatchedCallable(
        original=fn,
        reference_injections=reference_injections,
        reference_closing=reference_closing,
    )

    build_wrapper = (
        _get_async_patched
        if inspect.iscoroutinefunction(fn)
        else _get_sync_patched
    )
    wrapper = build_wrapper(fn, record)

    record.patched = wrapper
    _patched_registry.register_callable(record)

    return wrapper
667

668

669
def _is_fastapi_depends(param: Any) -> bool:
1✔
670
    return fastapi and isinstance(param, fastapi.params.Depends)
1✔
671

672

673
def _is_patched(fn) -> bool:
    """Tell whether ``fn`` is registered in the patched registry."""
    registered = _patched_registry.has_callable(fn)
    return registered
675

676

677
def _is_declarative_container(instance: Any) -> bool:
1✔
678
    return (isinstance(instance, type)
1✔
679
            and getattr(instance, "__IS_CONTAINER__", False) is True
680
            and getattr(instance, "declarative_parent", None) is None)
681

682

683
def _safe_is_subclass(instance: Any, cls: Type) -> bool:
1✔
684
    try:
1✔
685
        return issubclass(instance, cls)
1✔
686
    except TypeError:
×
687
        return False
×
688

689

690
class Modifier:
    """Base class of the modifiers that can accompany a marker."""

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return the provider to use in place of ``provider``.

        The base implementation has no body and so returns ``None``;
        the subclasses in this module override it.
        """
698

699

700
class TypeModifier(Modifier):
    """Modifier that converts a provider through its ``as_`` method."""

    def __init__(self, type_: Type) -> None:
        self.type_ = type_

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return ``provider.as_(...)`` for the stored type."""
        target_type = self.type_
        return provider.as_(target_type)
711

712

713
def as_int() -> TypeModifier:
    """Return int type modifier."""
    modifier = TypeModifier(int)
    return modifier
716

717

718
def as_float() -> TypeModifier:
    """Return float type modifier."""
    modifier = TypeModifier(float)
    return modifier
721

722

723
def as_(type_: Type) -> TypeModifier:
    """Return custom type modifier."""
    modifier = TypeModifier(type_)
    return modifier
726

727

728
class RequiredModifier(Modifier):
    """Modifier that marks an option as required, optionally typed.

    The ``as_*`` methods store a ``TypeModifier`` and return ``self`` so
    that calls can be chained.
    """

    def __init__(self) -> None:
        self.type_modifier = None

    def as_int(self) -> "RequiredModifier":
        """Request conversion to ``int``; returns ``self``."""
        self.type_modifier = TypeModifier(int)
        return self

    def as_float(self) -> "RequiredModifier":
        """Request conversion to ``float``; returns ``self``."""
        self.type_modifier = TypeModifier(float)
        return self

    def as_(self, type_: Type) -> "RequiredModifier":
        """Request conversion to ``type_``; returns ``self``."""
        self.type_modifier = TypeModifier(type_)
        return self

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return ``provider.required()``, converted when a type was requested."""
        required_provider = provider.required()
        if not self.type_modifier:
            return required_provider
        return required_provider.as_(self.type_modifier.type_)
754

755

756
def required() -> RequiredModifier:
    """Return required modifier."""
    modifier = RequiredModifier()
    return modifier
759

760

761
class InvariantModifier(Modifier):
    """Modifier that indexes a provider with another resolved provider."""

    def __init__(self, id: str) -> None:
        self.id = id

    def modify(
            self,
            provider: providers.ConfigurationOption,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Return ``provider[...]`` keyed by the provider resolved from ``self.id``."""
        key = providers_map.resolve_provider(self.id)
        return provider[key]
773

774

775
def invariant(id: str) -> InvariantModifier:
    """Return invariant modifier."""
    modifier = InvariantModifier(id)
    return modifier
778

779

780
class ProvidedInstance(Modifier):
    """Modifier that records a chain of accesses on a provided instance.

    Attribute access, item access and ``call()`` each append a segment
    and return ``self``; ``modify`` replays the recorded segments on the
    ``provided`` attribute of a provider.
    """

    TYPE_ATTRIBUTE = "attr"
    TYPE_ITEM = "item"
    TYPE_CALL = "call"

    def __init__(self) -> None:
        self.segments = []

    def __getattr__(self, item):
        """Record an attribute access and return ``self``.

        Raises:
            AttributeError: for double-underscore ("dunder") names.
        """
        # __getattr__ is consulted for every name that normal lookup
        # misses, which includes protocol probes such as ``__deepcopy__``
        # or ``__setstate__``. Answering those with ``self`` breaks copying
        # and silently appends a bogus segment, so they are reported as
        # missing instead.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        self.segments.append((self.TYPE_ATTRIBUTE, item))
        return self

    def __getitem__(self, item):
        """Record an item access and return ``self``."""
        self.segments.append((self.TYPE_ITEM, item))
        return self

    def call(self):
        """Record a call without arguments and return ``self``."""
        self.segments.append((self.TYPE_CALL, None))
        return self

    def modify(
            self,
            provider: providers.Provider,
            providers_map: ProvidersMap,
    ) -> providers.Provider:
        """Replay the recorded segments, starting from ``provider.provided``."""
        provider = provider.provided
        for type_, value in self.segments:
            if type_ == ProvidedInstance.TYPE_ATTRIBUTE:
                provider = getattr(provider, value)
            elif type_ == ProvidedInstance.TYPE_ITEM:
                provider = provider[value]
            elif type_ == ProvidedInstance.TYPE_CALL:
                provider = provider.call()
        return provider
815

816

817
def provided() -> ProvidedInstance:
    """Return provided instance modifier."""
    modifier = ProvidedInstance()
    return modifier
820

821

822
class ClassGetItemMeta(GenericMeta):
    """Metaclass that turns ``Cls[...]`` into a call of ``Cls``."""

    def __getitem__(cls, item):
        # Spike for Python 3.6
        # A tuple subscript is spread into positional arguments.
        arguments = item if isinstance(item, tuple) else (item,)
        return cls(*arguments)
828

829

830
class _Marker(Generic[T], metaclass=ClassGetItemMeta):
    """Base marker holding a provider reference and an optional modifier.

    Subscripting the class (``Marker[provider]`` or
    ``Marker[provider, modifier]``) creates an instance.
    """

    __IS_MARKER__ = True

    def __init__(
            self,
            provider: Union[providers.Provider, Container, str],
            modifier: Optional[Modifier] = None,
    ) -> None:
        # A declarative container class is replaced by its ``__self__``.
        self.provider = (
            provider.__self__
            if _is_declarative_container(provider)
            else provider
        )
        self.modifier = modifier

    def __class_getitem__(cls, item) -> T:
        # A tuple subscript is spread into positional arguments.
        arguments = item if isinstance(item, tuple) else (item,)
        return cls(*arguments)

    def __call__(self) -> T:
        """Return the marker itself."""
        return self
851

852

853
class Provide(_Marker):
    """Marker subclass without additions of its own."""
855

856

857
class Provider(_Marker):
    """Marker subclass without additions of its own."""
859

860

861
class Closing(_Marker):
    """Marker subclass without additions of its own."""
863

864

865
class AutoLoader:
    """Auto-wiring module loader.

    Automatically wire containers when modules are imported.
    """

    def __init__(self) -> None:
        self.containers = []
        self._path_hook = None

    def register_containers(self, *containers) -> None:
        """Add containers and install the import hook if it is not installed."""
        self.containers.extend(containers)
        if self.installed:
            return
        self.install()

    def unregister_containers(self, *containers) -> None:
        """Remove containers; uninstall the hook once none are left."""
        for container in containers:
            self.containers.remove(container)
        if len(self.containers) == 0:
            self.uninstall()

    def wire_module(self, module) -> None:
        """Wire ``module`` with every registered container."""
        for container in self.containers:
            container.wire(modules=[module])

    @property
    def installed(self) -> bool:
        """Whether this loader's path hook is present in ``sys.path_hooks``."""
        return self._path_hook in sys.path_hooks

    def install(self) -> None:
        """Put a wiring path hook at the front of ``sys.path_hooks``.

        Does nothing when the hook is already installed. The source and
        bytecode loaders wire each module right after executing it.
        """
        if self.installed:
            return

        auto_loader = self

        class SourcelessFileLoader(importlib.machinery.SourcelessFileLoader):
            def exec_module(self, module):
                super().exec_module(module)
                auto_loader.wire_module(module)

        class SourceFileLoader(importlib.machinery.SourceFileLoader):
            def exec_module(self, module):
                super().exec_module(module)
                auto_loader.wire_module(module)

        class ExtensionFileLoader(importlib.machinery.ExtensionFileLoader):
            ...

        self._path_hook = importlib.machinery.FileFinder.path_hook(
            (SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
            (SourceFileLoader, importlib.machinery.SOURCE_SUFFIXES),
            (ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES),
        )

        sys.path_hooks.insert(0, self._path_hook)
        self._reset_import_caches()

    def uninstall(self) -> None:
        """Remove the path hook; does nothing when it is not installed."""
        if not self.installed:
            return

        sys.path_hooks.remove(self._path_hook)
        self._reset_import_caches()

    @staticmethod
    def _reset_import_caches() -> None:
        # Clear the importer cache and invalidate finder caches so the
        # change to sys.path_hooks takes effect.
        sys.path_importer_cache.clear()
        importlib.invalidate_caches()
934

935

936
def register_loader_containers(*containers: Container) -> None:
    """Register containers in auto-wiring module loader."""
    auto_loader = _loader
    auto_loader.register_containers(*containers)
939

940

941
def unregister_loader_containers(*containers: Container) -> None:
    """Unregister containers from auto-wiring module loader."""
    auto_loader = _loader
    auto_loader.unregister_containers(*containers)
944

945

946
def install_loader() -> None:
    """Install auto-wiring module loader hook."""
    auto_loader = _loader
    auto_loader.install()
949

950

951
def uninstall_loader() -> None:
    """Uninstall auto-wiring module loader hook."""
    auto_loader = _loader
    auto_loader.uninstall()
954

955

956
def is_loader_installed() -> bool:
    """Check if auto-wiring module loader hook is installed."""
    auto_loader = _loader
    return auto_loader.installed
959

960

961
_patched_registry = PatchedRegistry()
1✔
962
_inspect_filter = InspectFilter()
1✔
963
_loader = AutoLoader()
1✔
964

965
# Optimizations
966
from ._cwiring import _get_sync_patched  # noqa
1✔
967
from ._cwiring import _async_inject  # noqa
1✔
968

969

970
# Wiring uses the following Python wrapper because there is
971
# no possibility to compile a first-class citizen coroutine in Cython.
972
def _get_async_patched(fn: F, patched: PatchedCallable) -> F:
    """Wrap coroutine function ``fn`` so that its injections are applied.

    The wrapper awaits ``_async_inject`` with the call arguments and the
    injections / closing entries currently stored on ``patched``.
    """
    async def _patched(*args, **kwargs):
        # The dictionaries are read from ``patched`` on every call, not
        # captured once when the wrapper is created.
        result = await _async_inject(
            fn,
            args,
            kwargs,
            patched.injections,
            patched.closing,
        )
        return result

    return functools.wraps(fn)(_patched)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc