• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 11808037861

13 Nov 2024 12:16AM UTC coverage: 91.334% (-0.1%) from 91.44%
11808037861

Pull #822

github

web-flow
Merge 76a359499 into bdf09a6ed
Pull Request #822: Added support for asyncio eager task factories

39 of 48 new or added lines in 2 files covered. (81.25%)

2 existing lines in 1 file now uncovered.

4880 of 5343 relevant lines covered (91.33%)

8.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.49
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import os
10✔
8
import socket
10✔
9
import sys
10✔
10
import threading
10✔
11
import weakref
10✔
12
from asyncio import (
10✔
13
    AbstractEventLoop,
14
    CancelledError,
15
    all_tasks,
16
    create_task,
17
    current_task,
18
    get_running_loop,
19
    sleep,
20
)
21
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
22
from collections import OrderedDict, deque
10✔
23
from collections.abc import (
10✔
24
    AsyncGenerator,
25
    AsyncIterator,
26
    Awaitable,
27
    Callable,
28
    Collection,
29
    Coroutine,
30
    Iterable,
31
    Iterator,
32
    MutableMapping,
33
    Sequence,
34
)
35
from concurrent.futures import Future
10✔
36
from contextlib import AbstractContextManager, suppress
10✔
37
from contextvars import Context, copy_context
10✔
38
from dataclasses import dataclass
10✔
39
from functools import partial, wraps
10✔
40
from inspect import (
10✔
41
    CORO_RUNNING,
42
    CORO_SUSPENDED,
43
    getcoroutinestate,
44
    iscoroutine,
45
)
46
from io import IOBase
10✔
47
from os import PathLike
10✔
48
from queue import Queue
10✔
49
from signal import Signals
10✔
50
from socket import AddressFamily, SocketKind
10✔
51
from threading import Thread
10✔
52
from types import TracebackType
10✔
53
from typing import (
10✔
54
    IO,
55
    Any,
56
    Optional,
57
    TypeVar,
58
    cast,
59
)
60
from weakref import WeakKeyDictionary
10✔
61

62
import sniffio
10✔
63

64
from .. import (
10✔
65
    CapacityLimiterStatistics,
66
    EventStatistics,
67
    LockStatistics,
68
    TaskInfo,
69
    abc,
70
)
71
from .._core._eventloop import claim_worker_thread, threadlocals
10✔
72
from .._core._exceptions import (
10✔
73
    BrokenResourceError,
74
    BusyResourceError,
75
    ClosedResourceError,
76
    EndOfStream,
77
    WouldBlock,
78
    iterate_exceptions,
79
)
80
from .._core._sockets import convert_ipv6_sockaddr
10✔
81
from .._core._streams import create_memory_object_stream
10✔
82
from .._core._synchronization import (
10✔
83
    CapacityLimiter as BaseCapacityLimiter,
84
)
85
from .._core._synchronization import Event as BaseEvent
10✔
86
from .._core._synchronization import Lock as BaseLock
10✔
87
from .._core._synchronization import (
10✔
88
    ResourceGuard,
89
    SemaphoreStatistics,
90
)
91
from .._core._synchronization import Semaphore as BaseSemaphore
10✔
92
from .._core._tasks import CancelScope as BaseCancelScope
10✔
93
from ..abc import (
10✔
94
    AsyncBackend,
95
    IPSockAddrType,
96
    SocketListener,
97
    UDPPacketType,
98
    UNIXDatagramPacketType,
99
)
100
from ..abc._eventloop import StrOrBytesPath
10✔
101
from ..lowlevel import RunVar
10✔
102
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
103

104
if sys.version_info >= (3, 10):
10✔
105
    from typing import ParamSpec
7✔
106
else:
107
    from typing_extensions import ParamSpec
3✔
108

109
if sys.version_info >= (3, 11):
10✔
110
    from asyncio import Runner
5✔
111
    from typing import TypeVarTuple, Unpack
5✔
112
else:
113
    import contextvars
5✔
114
    import enum
5✔
115
    import signal
5✔
116
    from asyncio import coroutines, events, exceptions, tasks
5✔
117

118
    from exceptiongroup import BaseExceptionGroup
5✔
119
    from typing_extensions import TypeVarTuple, Unpack
5✔
120

121
    class _State(enum.Enum):
5✔
122
        CREATED = "created"
5✔
123
        INITIALIZED = "initialized"
5✔
124
        CLOSED = "closed"
5✔
125

126
    class Runner:
        """
        Backport of ``asyncio.Runner`` for Python versions before 3.11.

        The event loop is created lazily on first use (with *loop_factory* if given,
        otherwise with ``events.new_event_loop()``) and is torn down by
        :meth:`close`, which also runs when the runner is used as a context manager.
        """

        # Copied from CPython 3.11
        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            # contextvars context copied in _lazy_init(); run() uses it by default
            self._context: Context | None = None
            # Number of SIGINTs received during the current run() call
            self._interrupt_count = 0
            # True once our loop has been installed with events.set_event_loop()
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop."""
            # A runner that never initialized (or was already closed) has nothing to do
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                # Cancel leftover tasks, then finalize async generators and the
                # default executor while the loop can still run callbacks
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    # Loops lacking the method fall back to the local backport
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)
                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(
            self,
            coro: Coroutine[Any, Any, T_Retval],
            *,
            context: Context | None = None,
        ) -> T_Retval:
            """Run a coroutine inside the embedded event loop."""
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context
            task = context.run(self._loop.create_task, coro)

            # Route SIGINT to _on_sigint() (which cancels the main task) but only in
            # the main thread and only if nobody replaced the default handler
            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # `signal.signal` may throw if `threading.main_thread` does
                    # not support signals (e.g. embedded interpreter with signals
                    # not registered - see gh-91880)
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                if self._interrupt_count > 0:
                    # Turn a cancellation caused by our SIGINT handler back into
                    # KeyboardInterrupt.
                    # NOTE(review): this class is only defined in the branch for
                    # Python < 3.11, where Task.uncancel() does not exist, so this
                    # getattr() presumably always yields None and CancelledError is
                    # re-raised instead — confirm whether that is intended.
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise  # CancelledError
            finally:
                # Put the default handler back unless someone replaced ours meanwhile
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            """Create and configure the event loop on first use."""
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Call set_event_loop only once to avoid calling
                    # attach_loop multiple times on child watchers
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(
            self, signum: int, frame: object, main_task: asyncio.Task
        ) -> None:
            """
            SIGINT handler installed by :meth:`run`.

            On the first interrupt, if *main_task* is still pending, cancel it and
            wake up the loop. Otherwise (a repeated interrupt, or the task already
            finished) raise :exc:`KeyboardInterrupt`.
            """
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wakeup loop if it is blocked by select() with long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
253

254
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
5✔
255
        to_cancel = tasks.all_tasks(loop)
5✔
256
        if not to_cancel:
5✔
257
            return
5✔
258

259
        for task in to_cancel:
5✔
260
            task.cancel()
5✔
261

262
        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
5✔
263

264
        for task in to_cancel:
5✔
265
            if task.cancelled():
5✔
266
                continue
5✔
267
            if task.exception() is not None:
4✔
268
                loop.call_exception_handler(
×
269
                    {
270
                        "message": "unhandled exception during asyncio.run() shutdown",
271
                        "exception": task.exception(),
272
                        "task": task,
273
                    }
274
                )
275

276
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
5✔
277
        """Schedule the shutdown of the default executor."""
278

279
        def _do_shutdown(future: asyncio.futures.Future) -> None:
×
280
            try:
×
281
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
×
282
                loop.call_soon_threadsafe(future.set_result, None)
×
283
            except Exception as ex:
×
284
                loop.call_soon_threadsafe(future.set_exception, ex)
×
285

286
        loop._executor_shutdown_called = True
×
287
        if loop._default_executor is None:
×
288
            return
×
289
        future = loop.create_future()
×
290
        thread = threading.Thread(target=_do_shutdown, args=(future,))
×
291
        thread.start()
×
292
        try:
×
293
            await future
×
294
        finally:
295
            thread.join()
×
296

297

298
# Type variables used in this module's annotations (e.g. Runner.run()'s return
# value and the value passed to task status started())
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")

# Cached root task, read and populated by find_root_task()
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
304

305

306
def find_root_task() -> asyncio.Task:
    """
    Return the task at the root of the current task tree.

    Candidates are tried in this order:

    #. the previously cached root task, if it has not finished yet
    #. an unfinished task carrying ``run_until_complete()``'s completion callback
       (or a callback from the ``uvloop.loop`` module); this one gets cached
    #. the host task of the outermost cancel scope of the current task
    #. the current task itself
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # The task handed to loop.run_until_complete() has that method's completion
    # callback attached (presumably uvloop attaches an equivalent of its own)
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        if any(
            callback is _run_until_complete_cb
            or getattr(callback, "__module__", None) == "uvloop.loop"
            for callback, _context in candidate._callbacks
        ):
            _root_task.set(candidate)
            return candidate

    # Otherwise climb to the outermost cancel scope of the current AnyIO task
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if state:
        outermost = state.cancel_scope
        while outermost and outermost._parent_scope is not None:
            outermost = outermost._parent_scope

        if outermost is not None:
            return cast(asyncio.Task, outermost._host_task)

    return current
335

336

337
def get_callable_name(func: Callable) -> str:
    """
    Return a dotted ``module.qualname`` name for *func*.

    Missing or empty ``__module__`` / ``__qualname__`` parts are left out, so the
    result may be just one of them, or an empty string.
    """
    parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(part for part in parts if part)
341

342

343
#
344
# Event loop
345
#
346

347
# Per-event-loop storage, weakly keyed so that an entry disappears together with
# its loop. NOTE(review): presumably backs RunVar values for this backend — confirm.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
348

349

350
def _task_started(task: asyncio.Task) -> bool:
10✔
351
    """Return ``True`` if the task has been started and has not finished."""
352
    # The task coro should never be None here, as we never add finished tasks to the
353
    # task list
354
    coro = task.get_coro()
10✔
355
    assert coro is not None
10✔
356
    try:
10✔
357
        return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
10✔
358
    except AttributeError:
×
359
        # task coro is async_genenerator_asend https://bugs.python.org/issue37771
360
        raise Exception(f"Cannot determine if task {task} has started or not") from None
×
361

362

363
#
364
# Timeouts and cancellation
365
#
366

367

368
def is_anyio_cancellation(exc: CancelledError) -> bool:
    """
    Tell whether *exc* was raised by an AnyIO cancel scope.

    Cancel scopes pass a message starting with ``"Cancelled by cancel scope "``
    to ``Task.cancel()``. This matches a cancellation from *any* AnyIO scope.
    """
    if not exc.args:
        return False

    message = exc.args[0]
    return isinstance(message, str) and message.startswith(
        "Cancelled by cancel scope "
    )
374

375

376
class CancelScope(BaseCancelScope):
    """
    asyncio implementation of AnyIO's cancel scope.

    While active, the scope tracks:

    * its host task (the task that entered it)
    * the tasks directly contained in it
    * its nested child scopes

    Cancelling it calls ``Task.cancel()`` with a ``"Cancelled by cancel scope <id>"``
    message (recognized by :func:`is_anyio_cancellation`). Delivery is re-scheduled
    with ``call_soon()`` for as long as contained tasks remain.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Plain object construction, bypassing any __new__ inherited from
        # BaseCancelScope
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        # The host task's previous innermost scope, recorded on __enter__()
        self._parent_scope: CancelScope | None = None
        # Scopes entered directly inside this one
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Event loop callback that fires at the deadline
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Pending re-run of _deliver_cancellation() started by this scope
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks directly contained in this scope (the host task while active)
        self._tasks: set[asyncio.Task] = set()
        # Task that entered this scope; None before entry and after exit
        self._host_task: asyncio.Task | None = None
        # How many times this scope, as the cancellation origin, has called
        # cancel() on its own host task; undone with uncancel() on exit
        self._cancel_calls: int = 0
        # host_task.cancelling() at __enter__() time (only set on Python 3.11+)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        # NOTE(review): this only rejects entering while already active;
        # re-entering after a completed 'with' block is not prevented here
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # No state for this task yet: this scope becomes its cancel scope
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest inside the task's current scope, taking the host task from it
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.remove(host_task)

        # Cancels right away if the deadline already passed, otherwise schedules it
        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            # Baseline used by __exit__() to undo only this scope's cancel requests
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Deactivate the scope and hand the host task back to the parent scope.

        Returns ``True`` (suppressing *exc_val*) only when every exception in it
        is a ``CancelledError`` that this scope may swallow (see :meth:`_uncancel`).

        :raises RuntimeError: if the scope is not active, is exited from another
            task, or is not the host task's current cancel scope
        """
        del exc_tb

        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        try:
            self._active = False
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            # Move the host task back to the parent scope
            self._tasks.remove(self._host_task)
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.remove(self)
                self._parent_scope._tasks.add(self._host_task)

            host_task_state.cancel_scope = self._parent_scope

            # Undo all cancellations done by this scope
            if self._cancelling is not None:
                while self._cancel_calls:
                    self._cancel_calls -= 1
                    if self._host_task.uncancel() <= self._cancelling:
                        break

            # We only swallow the exception iff it was an AnyIO CancelledError, either
            # directly as exc_val or inside an exception group and there are no cancelled
            # parent cancel scopes visible to us here
            not_swallowed_exceptions = 0
            swallow_exception = False
            if exc_val is not None:
                for exc in iterate_exceptions(exc_val):
                    if self._cancel_called and isinstance(exc, CancelledError):
                        if not (swallow_exception := self._uncancel(exc)):
                            not_swallowed_exceptions += 1
                    else:
                        not_swallowed_exceptions += 1

            # Restart the cancellation effort in the closest visible, cancelled parent
            # scope if necessary
            self._restart_cancellation_in_parent()
            return swallow_exception and not not_swallowed_exceptions
        finally:
            self._host_task = None
            # Drop the reference to the exception to avoid a reference cycle via
            # this frame
            del exc_val

    @property
    def _effectively_cancelled(self) -> bool:
        """
        Walk up from this scope through its parents.

        Return ``True`` at the first scope whose ``cancel()`` has been called, and
        ``False`` at the first shielded scope that isn't itself cancelled, or when
        the top is reached.
        """
        cancel_scope: CancelScope | None = self
        while cancel_scope is not None:
            if cancel_scope._cancel_called:
                return True

            if cancel_scope.shield:
                return False

            cancel_scope = cancel_scope._parent_scope

        return False

    @property
    def _parent_cancellation_is_visible_to_us(self) -> bool:
        """``True`` if this scope is unshielded and its parent is effectively cancelled."""
        return (
            self._parent_scope is not None
            and not self.shield
            and self._parent_scope._effectively_cancelled
        )

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """
        Decide whether *cancelled_exc* may be swallowed by this scope.

        The ``__context__`` chain of ``CancelledError`` objects is followed until
        an AnyIO-originated cancellation is found. The decision is stored in
        ``cancelled_caught`` and returned.
        """
        # Not reached from __exit__(), which clears _host_task only afterwards
        if self._host_task is None:
            self._cancel_calls = 0
            return True

        while True:
            if is_anyio_cancellation(cancelled_exc):
                # Only swallow the cancellation exception if it's an AnyIO cancel
                # exception and there are no other cancel scopes down the line pending
                # cancellation
                self._cancelled_caught = (
                    self._effectively_cancelled
                    and not self._parent_cancellation_is_visible_to_us
                )
                return self._cancelled_caught

            # Sometimes third party frameworks catch a CancelledError and raise a new
            # one, so as a workaround we have to look at the previous ones in
            # __context__ too for a matching cancel message
            if isinstance(cancelled_exc.__context__, CancelledError):
                cancelled_exc = cancelled_exc.__context__
                continue

            return False

    def _timeout(self) -> None:
        """Cancel now if the deadline has passed, else schedule a re-check at it."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Any contained task keeps the delivery loop going
            should_retry = True
            # Skip tasks that already have an undelivered cancel request
            # (private asyncio Task attribute)
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            if task is not current and (task is self._host_task or _task_started(task)):
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                # Skip tasks whose awaited future has already completed; they are
                # retried on the next delivery round
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    if task is origin._host_task:
                        origin._cancel_calls += 1

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # Only restart if that scope has no delivery round pending already
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def cancel(self) -> None:
        """
        Cancel this scope (idempotent).

        If the scope has already been entered, delivery starts immediately.
        Otherwise it starts in :meth:`__enter__`.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        # Replace any scheduled timeout; reschedule only for an active scope that
        # has not been cancelled yet
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Unshielding exposes this scope to a cancelled ancestor again
            if not value:
                self._restart_cancellation_in_parent()
652

653

654
#
655
# Task states
656
#
657

658

659
class TaskState:
    """
    Side-car record holding AnyIO-specific data for one asyncio task.

    The data lives here rather than on the ``Task`` instance because there are no
    guarantees about how ``Task`` is implemented.
    """

    __slots__ = ("parent_id", "cancel_scope", "__weakref__")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        # The task's innermost active cancel scope, if any
        self.cancel_scope = cancel_scope
        # Integer identifying the task's parent; None when not set
        self.parent_id = parent_id
670

671

672
class TaskStateStore(
    MutableMapping["Awaitable[Any] | asyncio.Task | None", "TaskState"]
):
    """
    Mapping of tasks (and not-yet-wrapped awaitables) to their :class:`TaskState`.

    State for a task is held in a :class:`~weakref.WeakKeyDictionary` keyed by
    the task, so the entry disappears together with the task.

    State may also be registered up front under the awaitable (coroutine) that a
    task will later be created from ("preliminary" state). Looking up such a task
    falls back to its ``get_coro()``. Preliminary entries are strongly referenced
    and must be deleted explicitly.

    Every key yielded by iteration can be looked up again, so the inherited
    ``Mapping`` helpers (``in``, ``get()``, ``items()``) work for both kinds of
    key. ``None`` is never a valid key: lookups and deletions raise
    :exc:`KeyError`, and insertion raises :exc:`ValueError`.
    """

    def __init__(self) -> None:
        self._task_states: WeakKeyDictionary[asyncio.Task, TaskState] = (
            WeakKeyDictionary()
        )
        self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}

    def __getitem__(self, key: Awaitable[Any] | asyncio.Task | None, /) -> TaskState:
        """
        Return the state for *key*.

        :raises KeyError: if no state is registered for *key*
        """
        if isinstance(key, asyncio.Task):
            try:
                return self._task_states[key]
            except KeyError:
                pass

            # Fall back to state registered for the task's coroutine before the
            # task object existed
            coro = key.get_coro()
            if coro is not None:
                try:
                    return self._preliminary_task_states[coro]
                except KeyError:
                    pass
        elif key is not None:
            # Non-task keys only ever live in the preliminary table; __iter__()
            # yields them, so they must be retrievable here too
            try:
                return self._preliminary_task_states[key]
            except KeyError:
                pass

        raise KeyError(key)

    def __setitem__(
        self, key: asyncio.Task | Awaitable[Any] | None, value: TaskState, /
    ) -> None:
        """
        Store *value* for *key*.

        Tasks are stored weakly; other awaitables are stored as preliminary state.

        :raises ValueError: if *key* is ``None``
        """
        if isinstance(key, asyncio.Task):
            self._task_states[key] = value
        elif key is None:
            raise ValueError("cannot insert None")
        else:
            self._preliminary_task_states[key] = value

    def __delitem__(self, key: asyncio.Task | Awaitable[Any] | None, /) -> None:
        """
        Remove the state stored directly under *key*.

        Deleting a task does not remove preliminary state for its coroutine.

        :raises KeyError: if *key* is ``None`` or has no state stored under it
        """
        if isinstance(key, asyncio.Task):
            del self._task_states[key]
        elif key is None:
            raise KeyError(key)
        else:
            del self._preliminary_task_states[key]

    def __len__(self) -> int:
        return len(self._task_states) + len(self._preliminary_task_states)

    def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
        # Tasks first, then preliminary (not yet wrapped) awaitables
        yield from self._task_states
        yield from self._preliminary_task_states
712

713

714
# Module-global registry of TaskState objects, keyed by task (or by a coroutine
# that has not been wrapped in a task yet)
_task_states = TaskStateStore()
715

716

717
#
718
# Task groups
719
#
720

721

722
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status that reports a task's startup through an asyncio future.

    :param future: resolved with the value passed to :meth:`started`
    :param parent_id: recorded as the calling task's parent once started
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._future = future
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """
        Signal that the task has started, handing *value* to the waiting side.

        The calling task's recorded parent is then switched to *parent_id*.

        :raises RuntimeError: if the future was already resolved (a cancelled
            future is silently tolerated)
        """
        future = self._future
        try:
            future.set_result(value)
        except asyncio.InvalidStateError:
            # A cancelled future only means nobody waits for the result anymore;
            # any other finished state means started() was called before
            resolved_before = not future.cancelled()
            if resolved_before:
                raise RuntimeError(
                    "called 'started' twice on the same task status"
                ) from None

        me = cast(asyncio.Task, current_task())
        _task_states[me].parent_id = self._parent_id
738

739

740
async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None:
10✔
741
    tasks = set(tasks)
10✔
742
    waiter = get_running_loop().create_future()
10✔
743

744
    def on_completion(task: asyncio.Task[object]) -> None:
10✔
745
        tasks.discard(task)
10✔
746
        if not tasks and not waiter.done():
10✔
747
            waiter.set_result(None)
10✔
748

749
    for task in tasks:
10✔
750
        task.add_done_callback(on_completion)
10✔
751
        del task
10✔
752

753
    try:
10✔
754
        await waiter
10✔
755
    finally:
756
        while tasks:
10✔
757
            tasks.pop().remove_done_callback(on_completion)
10✔
758

759

760
class TaskGroup(abc.TaskGroup):
    """Asyncio implementation of an AnyIO task group.

    Child tasks run inside :attr:`cancel_scope`. When the ``async with`` block exits,
    the group waits for all children. The scope is cancelled if the block raised or
    a child failed. Collected errors are re-raised together in a
    :exc:`BaseExceptionGroup`.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        # True from __aenter__ until the wait phase of __aexit__ has finished
        self._active = False
        # Non-cancellation errors from the host block and from child tasks
        self._exceptions: list[BaseException] = []
        # Child tasks that have not finished yet
        self._tasks: set[asyncio.Task] = set()

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        try:
            if exc_val is not None:
                # The host block raised: cancel the children; cancellation
                # exceptions are not collected as errors
                self.cancel_scope.cancel()
                if not isinstance(exc_val, CancelledError):
                    self._exceptions.append(exc_val)

            try:
                if self._tasks:
                    with CancelScope() as wait_scope:
                        while self._tasks:
                            try:
                                await _wait(self._tasks)
                            except CancelledError as exc:
                                # Shield the scope against further cancellation attempts,
                                # as they're not productive (#695)
                                wait_scope.shield = True
                                self.cancel_scope.cancel()

                                # Set exc_val from the cancellation exception if it was
                                # previously unset. However, we should not replace a native
                                # cancellation exception with one raised by a cancel scope.
                                if exc_val is None or (
                                    isinstance(exc_val, CancelledError)
                                    and not is_anyio_cancellation(exc)
                                ):
                                    exc_val = exc
                else:
                    # If there are no child tasks to wait on, run at least one checkpoint
                    # anyway
                    await AsyncIOBackend.cancel_shielded_checkpoint()

                self._active = False
                if self._exceptions:
                    raise BaseExceptionGroup(
                        "unhandled errors in a TaskGroup", self._exceptions
                    )
                elif exc_val:
                    raise exc_val
            except BaseException as exc:
                # Let the group's cancel scope decide whether to swallow the exception
                if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
                    return True

                raise

            return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Drop references to the exceptions/traceback (presumably to avoid
            # reference cycles through the frames)
            del exc_val, exc_tb, self._exceptions

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """Create a child task running ``func(*args)`` inside the group's cancel scope.

        :param func: a callable that must return a coroutine object
        :param args: positional arguments for ``func``
        :param name: the task name; derived from ``func`` when ``None``
        :param task_status_future: when given (by :meth:`start`), ``func`` also gets
            a ``task_status`` keyword argument whose ``started()`` resolves it
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` does not return a coroutine object
        :return: the new task
        """

        def task_done(_task: asyncio.Task) -> None:
            # Unregister the finished task everywhere it was recorded. Only the
            # callback argument is used, never the enclosing ``task`` variable, so
            # this works no matter when the callback is invoked.
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Use the earliest CancelledError of a chain of cancellations
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # The future can only be in the cancelled state if the host task was
                # cancelled, so return immediately instead of adding one more
                # CancelledError to the exceptions list
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._effectively_cancelled:
                        self.cancel_scope.cancel()
                else:
                    # The child failed before calling task_status.started(): report
                    # the error to the waiting start() call instead of the group
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # The child counts as a child of the start() caller until it calls
            # task_status.started(), which re-parents it to the host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        # Make the spawned task inherit the task group's cancel scope. The state is
        # registered under the coroutine first, presumably so that lookups made while
        # an eager task factory runs the coroutine inside create_task() find it.
        _task_states[coro] = task_state = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        name = get_callable_name(func) if name is None else str(name)
        try:
            task = create_task(coro, name=name)
        except BaseException:
            del _task_states[coro]
            raise

        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)

        # Re-key the state from the coroutine to the task now that it exists
        del _task_states[coro]
        _task_states[task] = task_state
        if task.done():
            # This can happen with eager task factories
            task_done(task)
        else:
            task.add_done_callback(task_done)

        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """Start a child task and wait until it calls ``task_status.started()``.

        :return: the value the child passed to ``task_status.started()``
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
942

943

944
#
945
# Threads
946
#
947

948
# Alias for a (return value, exception) pair; either member may be None
_Retval_Queue_Type = tuple[
    Optional[T_Retval],
    Optional[BaseException],
]
949

950

951
class WorkerThread(Thread):
    """Thread that runs synchronous callables on behalf of the event loop.

    Work items are taken from :attr:`queue` as
    ``(context, func, args, future, cancel_scope)`` tuples; ``None`` asks the thread
    to exit. Each outcome is handed back to the event loop thread by scheduling
    :meth:`_report_result` with ``call_soon_threadsafe``.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Bounded to two pending items (presumably one job plus the shutdown
        # sentinel) — TODO confirm
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        # Runs in the event loop thread: mark this worker idle again, then settle the
        # future unless it has already been cancelled
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # asyncio futures refuse StopIteration, so wrap it like coroutines do
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        with claim_worker_thread(AsyncIOBackend, self.loop):
            # A None item is the shutdown command
            while (item := self.queue.get()) is not None:
                context, func, args, future, cancel_scope = item
                if not future.cancelled():
                    outcome = None
                    error: BaseException | None = None
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        outcome = context.run(func, *args)
                    except BaseException as exc:
                        error = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    # Results can't be delivered once the loop has been closed
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, outcome, error
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        # ``f`` is ignored; the parameter lets this method be used as a callback
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        with suppress(ValueError):
            # The worker may not currently be in the idle list
            self.idle_workers.remove(self)
1024

1025

1026
# Worker threads currently available for reuse, and the set of all worker threads.
# Both are RunVars, so presumably each event loop run gets its own pool.
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar(
    "_threadpool_workers"
)
1030

1031

1032
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the asyncio event loop running when it is created."""

    def __new__(cls) -> BlockingPortal:
        # Plain allocation; presumably bypasses the base class __new__, which may
        # redirect to a backend-specific class — TODO confirm
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        # get_running_loop() raises RuntimeError unless called inside a running loop
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        # From a foreign thread, run start_soon() inside the portal's event loop so
        # the portal's task group spawns self._call_func(func, args, kwargs, future)
        start_soon = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_soon, call_args, self._loop)
1053

1054

1055
#
1056
# Subprocesses
1057
#
1058

1059

1060
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Adapts an :class:`asyncio.StreamReader` to AnyIO's byte receive stream API."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            # An empty read is treated as the end of the stream
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        # Pending and subsequent reads on the reader will raise ClosedResourceError
        self._stream.set_exception(ClosedResourceError())
        await AsyncIOBackend.checkpoint()
1074

1075

1076
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Adapts an :class:`asyncio.StreamWriter` to AnyIO's byte send stream API."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        # Buffer the bytes, then wait for the writer's flow control to allow more
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        self._stream.close()
        await AsyncIOBackend.checkpoint()
1087

1088

1089
@dataclass(eq=False)
class Process(abc.Process):
    """AnyIO process handle wrapping :class:`asyncio.subprocess.Process`.

    Any of the standard streams may be ``None`` when it was not piped.
    """

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        with CancelScope(shield=True) as scope:
            # Close whichever pipes exist while shielded from cancellation
            for stream in (self._stdin, self._stdout, self._stderr):
                if stream:
                    await stream.aclose()

            # Waiting for the exit may be cancelled; in that case kill the process
            # and wait for it again (shielded) before propagating the exception
            scope.shield = False
            try:
                await self.wait()
            except BaseException:
                scope.shield = True
                self.kill()
                await self.wait()
                raise

    async def wait(self) -> int:
        """Wait for the process to exit and return its exit code."""
        return await self._process.wait()

    def terminate(self) -> None:
        """Ask the process to terminate."""
        self._process.terminate()

    def kill(self) -> None:
        """Kill the process."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """Send ``signal`` to the process."""
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        """The process id."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The exit code, or ``None`` if the process has not exited yet."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
1145

1146

1147
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down the worker processes belonging to this event loop.

    Only processes that are still running (``returncode is None``) are handled:
    their pipe transports are closed and they are killed. Processes that have
    already exited are skipped. Everything happens synchronously (no ``await``).

    :param workers: the worker processes to shut down
    :param _task: unused; presumably present so this can serve as a task done
        callback
    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        if process.returncode is not None:
            # Already exited: there is nothing left to shut down
            continue

        # Any of the three pipes may be None (see the Process fields)
        for stream in (process._stdin, process._stdout, process._stderr):
            if stream is not None:
                stream._stream._transport.close()  # type: ignore[union-attr]

        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
1170

1171

1172
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    Sleeps until cancelled; on cancellation, kills every worker that is still
    running and then closes all of them. The cancellation is not re-raised.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        # First signal every worker that is still alive...
        for proc in workers:
            if proc.returncode is None:
                proc.kill()

        # ...then close each of them in turn
        for proc in workers:
            await proc.aclose()
1190

1191

1192
#
1193
# Sockets and networking
1194
#
1195

1196

1197
class StreamProtocol(asyncio.Protocol):
    """Buffering protocol that backs :class:`SocketStream`.

    Received chunks are queued in :attr:`read_queue` and signalled via
    :attr:`read_event`; :attr:`write_event` is unset while the transport has
    paused writing.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None
    is_at_eof: bool = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()  # writing is allowed until pause_writing() is called
        # A zero high-water mark makes the transport pause writing as soon as
        # anything is buffered
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up any readers and writers so they notice the closed connection
        for event in (self.read_event, self.write_event):
            event.set()

    def data_received(self, data: bytes) -> None:
        # ProactorEventloop sometimes sends bytearray instead of bytes
        chunk = bytes(data)
        self.read_queue.append(chunk)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.is_at_eof = True
        self.read_event.set()
        # Returning True keeps the transport open (the write side stays usable)
        return True

    def pause_writing(self) -> None:
        # Swap in a fresh, unset event rather than clearing the current one
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
1234

1235

1236
class DatagramProtocol(asyncio.DatagramProtocol):
    """Buffering protocol for UDP sockets.

    Received ``(data, address)`` pairs are queued in :attr:`read_queue` and
    signalled via :attr:`read_event`.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bounded queue: once full, appending discards the oldest datagram
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up any readers and writers
        for event in (self.read_event, self.write_event):
            event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        normalized = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, normalized))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1265

1266

1267
class SocketStream(abc.SocketStream):
    """Socket byte stream built on an asyncio transport and :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        with self._receive_guard:
            protocol = self._protocol
            needs_data = (
                not protocol.read_event.is_set()
                and not self._transport.is_closing()
                and not protocol.is_at_eof
            )
            if needs_data:
                # Let the transport read only while a receiver is waiting for data
                self._transport.resume_reading()
                await protocol.read_event.wait()
                self._transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                # Nothing buffered: report why the stream has no more data
                if self._closed:
                    raise ClosedResourceError from None
                elif protocol.exception:
                    raise protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Hand out max_bytes; put the remainder back at the front of the queue
                protocol.read_queue.appendleft(chunk[max_bytes:])
                chunk = chunk[:max_bytes]

            # Once the queue is drained, clear the flag so the next call waits for
            # new data
            if not protocol.read_queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            if self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise

                raise BrokenResourceError from exc

            # Wait until the transport allows more writing (see pause_writing)
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        with suppress(OSError):
            self._transport.write_eof()

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        with suppress(OSError):
            self._transport.write_eof()

        # Close gracefully, give the loop one iteration, then abort whatever is left
        self._transport.close()
        await sleep(0)
        self._transport.abort()
10✔
1350

1351

1352
class _RawSocketMixin:
    """Shared plumbing for streams that operate directly on a raw socket.

    Readiness is awaited through one-shot futures registered with the loop's
    ``add_reader()``/``add_writer()``. :meth:`aclose` closes the socket and
    resolves any pending future so that waiters wake up.
    """

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        def on_done(f: object) -> None:
            # Forget the future and stop watching the socket once it resolves
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        fut = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, fut.set_result, None)
        fut.add_done_callback(on_done)
        return fut

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        def on_done(f: object) -> None:
            # Forget the future and stop watching the socket once it resolves
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        fut = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, fut.set_result, None)
        fut.add_done_callback(on_done)
        return fut

    async def aclose(self) -> None:
        if self._closing:
            return

        self._closing = True
        if self.__raw_socket.fileno() != -1:
            self.__raw_socket.close()

        # Wake up any reader/writer still waiting for readiness
        for pending in (self._receive_future, self._send_future):
            if pending:
                pending.set_result(None)
1396

1397

1398
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """UNIX domain socket stream with file descriptor passing (``SCM_RIGHTS``)."""

    async def send_eof(self) -> None:
        # Half-close: shut down only the writing direction
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    # Nothing to read yet: wait for readability and retry
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if data:
                    return data
                raise EndOfStream

    async def send(self, item: bytes) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    # send() may be partial; keep only the unsent tail
                    remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                    continue
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc

                if not message and not ancdata:
                    raise EndOfStream
                break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore any trailing bytes that don't form a whole integer
            usable = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
            fds.frombytes(cmsg_data[:usable])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        # Ints are used as-is and IO objects contribute their fileno(); anything else
        # is skipped
        filenos: list[int] = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    raise BrokenResourceError from exc
                else:
                    break
1512

1513

1514
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections through the running asyncio loop."""

    # Cancel scope around an in-flight ``sock_accept()``, so that aclose() can
    # interrupt it; None whenever no accept is in progress
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _drop_accept_reader(self) -> None:
        """Unregister the listening socket's reader callback, where supported.

        Workaround for https://bugs.python.org/issue41317: a cancelled
        ``sock_accept()`` can leave its reader registered with the event loop.
        """
        with suppress(ValueError, NotImplementedError):
            self._loop.remove_reader(self._raw_socket)

    async def accept(self) -> abc.SocketStream:
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    raw_client, _ = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    self._drop_accept_reader()
                    if not self._closed:
                        raise

                    # The cancellation came from aclose(), not from the caller
                    raise ClosedResourceError from None
                finally:
                    self._accept_scope = None

        raw_client.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, raw_client
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        if self._closed:
            return

        self._closed = True
        scope = self._accept_scope
        if scope:
            # Interrupt a pending accept() and let it observe the cancellation
            self._drop_accept_reader()
            scope.cancel()
            await sleep(0)

        self._raw_socket.close()
1572

1573

1574
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections on a non-blocking UNIX domain socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False
        # Future that a pending accept() is parked on while waiting for the listening
        # socket to become readable (None when nobody is waiting). aclose() resolves
        # it so that accept() does not stay blocked on a closed socket forever.
        self._accept_waiter: asyncio.Future[None] | None = None

    async def accept(self) -> abc.SocketStream:
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    await self._wait_for_connection()
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def _wait_for_connection(self) -> None:
        """Wait until the listening socket is readable or the listener is closed."""
        waiter: asyncio.Future[None] = asyncio.Future()

        def wake() -> None:
            # The reader may fire again before the waiting task gets to run
            if not waiter.done():
                waiter.set_result(None)

        self._loop.add_reader(self.__raw_socket, wake)
        self._accept_waiter = waiter
        try:
            await waiter
        finally:
            self._accept_waiter = None
            # If aclose() ran, it already unregistered the reader while the
            # descriptor was still valid
            if not self._closed:
                self._loop.remove_reader(self.__raw_socket)

    async def aclose(self) -> None:
        self._closed = True
        if (waiter := self._accept_waiter) is not None:
            # Unregister before closing, while the descriptor is still valid, then
            # wake the pending accept() so it retries on the closed socket and
            # raises ClosedResourceError instead of hanging
            with suppress(ValueError, NotImplementedError):
                self._loop.remove_reader(self.__raw_socket)

            if not waiter.done():
                waiter.set_result(None)

        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
1609

1610

1611
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport and protocol."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set by aclose(); distinguishes a local close from the transport going away
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            # Block only if nothing is buffered and the transport is still open
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None

    async def send(self, item: UDPPacketType) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
1657

1658

1659
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket backed by an asyncio datagram transport and protocol."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set by aclose(); distinguishes a local close from the transport going away
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            # Block only if nothing is buffered and the transport is still open
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None

            # Only the payload is returned; the sender address is dropped
            return packet[0]

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1707

1708

1709
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket driven directly on a non-blocking socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    # Read at most 65536 bytes per datagram
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: UNIXDatagramPacketType) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
1743

1744

1745
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket driven directly on a non-blocking socket."""

    async def receive(self) -> bytes:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    # Read at most 65536 bytes per datagram
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: bytes) -> None:
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
1779

1780

1781
# RunVar-scoped registries mapping keys to asyncio.Event objects, one for read
# readiness and one for write readiness (NOTE(review): presumably keyed by socket or
# file descriptor by the readiness helpers elsewhere in this module — confirm)
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
1783

1784

1785
#
1786
# Synchronization
1787
#
1788

1789

1790
class Event(BaseEvent):
    """asyncio implementation of anyio's Event, wrapping :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        self._event.set()

    def is_set(self) -> bool:
        return self._event.is_set()

    async def wait(self) -> None:
        if not self.is_set():
            await self._event.wait()
        else:
            # Already set: still yield to the event loop, as a checkpoint
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        waiter_count = len(self._event._waiters)
        return EventStatistics(waiter_count)
1811

1812

1813
class Lock(BaseLock):
    """asyncio implementation of a non-reentrant lock with FIFO hand-off.

    On release, ownership is passed directly to the first queued task whose
    future is still pending.
    """

    def __new__(cls, *, fast_acquire: bool = False) -> Lock:
        return object.__new__(cls)

    def __init__(self, *, fast_acquire: bool = False) -> None:
        self._fast_acquire = fast_acquire
        self._owner_task: asyncio.Task | None = None
        # (task, future) pairs in arrival order; a waiter's future is resolved when
        # ownership is handed to that task
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = task

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        if self._owner_task == task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        fut: asyncio.Future[None] = asyncio.Future()
        item = task, fut
        self._waiters.append(item)
        try:
            await fut
        except CancelledError:
            self._waiters.remove(item)
            if self._owner_task is task:
                # Ownership was handed to us before the cancellation got through
                self.release()
            elif self._owner_task is None:
                # The lock was released while our entry was queued but already
                # cancelled, so release() woke nobody and left the lock free. Tasks
                # that arrived since then were kept off the fast path by our stale
                # entry; hand the lock to the next live one or it would wait forever.
                self._wake_next_waiter()

            raise

        self._waiters.remove(item)

    def _wake_next_waiter(self) -> None:
        """Give ownership to the first queued task still waiting, else free the lock."""
        for task, fut in self._waiters:
            if not fut.done():
                self._owner_task = task
                fut.set_result(None)
                return

        self._owner_task = None

    def acquire_nowait(self) -> None:
        task = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            self._owner_task = task
            return

        if self._owner_task is task:
            raise RuntimeError("Attempted to acquire an already held Lock")

        raise WouldBlock

    def locked(self) -> bool:
        return self._owner_task is not None

    def release(self) -> None:
        if self._owner_task != current_task():
            raise RuntimeError("The current task is not holding this lock")

        self._wake_next_waiter()

    def statistics(self) -> LockStatistics:
        task_info = AsyncIOTaskInfo(self._owner_task) if self._owner_task else None
        return LockStatistics(self.locked(), task_info, len(self._waiters))
1885

1886

1887
class Semaphore(BaseSemaphore):
    """asyncio implementation of a counting semaphore with FIFO hand-off.

    On release, a token goes directly to the first queued waiter whose future
    is still pending; only when there is none is the counter incremented.
    """

    def __new__(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> Semaphore:
        return object.__new__(cls)

    def __init__(
        self,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ):
        super().__init__(initial_value, max_value=max_value)
        self._value = initial_value
        self._max_value = max_value
        self._fast_acquire = fast_acquire
        # Futures of queued acquirers in arrival order; release() resolves and
        # removes the first one that has not been cancelled
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        if self._value > 0 and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1

            # Unless on the "fast path", yield control of the event loop so that other
            # tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    self.release()
                    raise

            return

        fut: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(fut)
        try:
            await fut
        except CancelledError:
            try:
                self._waiters.remove(fut)
            except ValueError:
                # release() already handed us a token before the cancellation
                # arrived; pass it on
                self.release()
            else:
                # release() skips cancelled waiters, so tokens may have gone back to
                # the counter while our cancelled future was still queued, and that
                # stale entry kept later arrivals off the fast path. Forward those
                # tokens to live waiters, or they would wait forever.
                self._wake_waiters()

            raise

    def _wake_waiters(self) -> None:
        """Hand available tokens to queued waiters that are still waiting."""
        while self._value > 0:
            waiter = next((fut for fut in self._waiters if not fut.done()), None)
            if waiter is None:
                return

            self._waiters.remove(waiter)
            self._value -= 1
            waiter.set_result(None)

    def acquire_nowait(self) -> None:
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        for fut in self._waiters:
            if not fut.cancelled():
                fut.set_result(None)
                self._waiters.remove(fut)
                return

        self._value += 1

    @property
    def value(self) -> int:
        return self._value

    @property
    def max_value(self) -> int | None:
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        return SemaphoreStatistics(len(self._waiters))
1966

1967

1968
class CapacityLimiter(BaseCapacityLimiter):
    """asyncio implementation of anyio's CapacityLimiter.

    Borrowers (tasks, or arbitrary objects when using the ``*_on_behalf_of``
    variants) each hold at most one token. Blocked borrowers are queued in
    arrival order and woken one at a time as tokens become free.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        # Waiting borrowers in arrival order, mapped to the event that wakes them
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Notify waiting tasks that they have acquired the limiter
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        return self._total_tokens - len(self._borrowers)

    def _notify_next_waiter(self) -> None:
        """Wake the longest-waiting borrower if this limiter has free capacity now."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def acquire_nowait(self) -> None:
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A token was handed to us, but we are bailing out before taking
                    # it; pass it on, or the next waiter would keep waiting despite
                    # free capacity
                    self._notify_next_waiter()

                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def release(self) -> None:
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        self._notify_next_waiter()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
2080

2081

2082
# RunVar holding the default CapacityLimiter (NOTE(review): presumably the limiter used
# for worker threads when no explicit limiter is given — confirm with its users)
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
2083

2084

2085
#
2086
# Operating system signals
2087
#
2088

2089

2090
class _SignalReceiver:
    """Context manager and async iterator over received OS signals.

    Entering installs event-loop signal handlers for the requested signals.
    Each delivered signal is queued and handed out by ``__anext__`` in arrival
    order.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[Signals] = deque()
        # Resolved by _deliver() to wake a consumer blocked in __anext__()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        self._signal_queue.append(signum)
        if self._future.done():
            return

        self._future.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Deduplicate so that each signal gets exactly one handler
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # Nothing queued yet: park on a fresh future until _deliver() resolves it
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
2130

2131

2132
#
2133
# Testing and debugging
2134
#
2135

2136

2137
class AsyncIOTaskInfo(TaskInfo):
    """TaskInfo for an asyncio task that keeps only a weak reference to the task."""

    def __init__(self, task: asyncio.Task):
        state = _task_states.get(task)
        parent_id = None if state is None else state.parent_id
        super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
        self._task = weakref.ref(task)

    def has_pending_cancellation(self) -> bool:
        task = self._task()
        if not task:
            # The task no longer exists, so it cannot have a pending cancellation
            return False

        if sys.version_info >= (3, 11):
            if task.cancelling():
                return True
        elif (
            isinstance(task._fut_waiter, asyncio.Future)
            and task._fut_waiter.cancelled()
        ):
            # Pre-3.11 fallback: the future the task is waiting on was cancelled
            return True

        state = _task_states.get(task)
        scope = state.cancel_scope if state else None
        if scope:
            return scope._effectively_cancelled

        return False
2167

2168

2169
class TestRunner(abc.TestRunner):
    """Run async tests and fixtures on one asyncio event loop for pytest.

    Every test and fixture coroutine is forwarded through a memory object stream
    to one long-lived "runner task", so they all execute inside the same task.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop factory takes precedence over use_uvloop
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        # Exceptions reported to the loop's exception handler, re-raised later
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        self.get_loop().set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        exception = context.get("exception")
        if not isinstance(exception, Exception):
            loop.default_exception_handler(context)
            return

        self._exceptions.append(exception)

    def _raise_async_exceptions(self) -> None:
        """Re-raise exceptions collected from asynchronous callbacks, if any.

        A single exception is raised as-is; several are wrapped in a group.
        """
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        from _pytest.outcomes import OutcomeException

        with receive_stream, self._send_stream:
            async for awaitable, result_future in receive_stream:
                try:
                    outcome = await awaitable
                except CancelledError as exc:
                    if not result_future.cancelled():
                        result_future.cancel(*exc.args)

                    raise
                except BaseException as exc:
                    if not result_future.cancelled():
                        result_future.set_exception(exc)

                    # Ordinary failures (including pytest outcomes) are reported
                    # through the future; anything else also stops the runner task
                    if not isinstance(exc, (Exception, OutcomeException)):
                        raise
                else:
                    if not result_future.cancelled():
                        result_future.set_result(outcome)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        if not self._runner_task:
            # Start the runner task lazily on first use
            self._send_stream, receive_stream = create_memory_object_stream[
                tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        awaitable = func(*args, **kwargs)
        result_future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((awaitable, result_future))
        return await result_future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        agen = fixture_func(**kwargs)
        value: T_Retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(agen.asend, None)
        )
        self._raise_async_exceptions()

        yield value

        # Run the teardown part; the generator is expected to finish here
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(agen.asend, None)
            )
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(agen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        value = self.get_loop().run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return value

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            # Report the test failure together with any callback exceptions
            self._exceptions.append(exc)

        self._raise_async_exceptions()
2314

2315

2316
class AsyncIOBackend(AsyncBackend):
    """AnyIO backend implementation built on the standard library :mod:`asyncio`."""

    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func(*args)`` in a new event loop and return its result.

        Recognized ``options``: ``debug``, ``loop_factory`` and ``use_uvloop`` (the
        latter is only consulted when no ``loop_factory`` was given).

        .. note:: ``kwargs`` is not forwarded to ``func``; only ``args`` are passed.

        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            # Name the root task after the callable and register a task state
            # (no parent task id, no cancel scope) for the duration of the call
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", None)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            # Imported lazily so uvloop is only required when actually requested
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        """
        Return the running event loop as the opaque token.

        The thread-crossing methods of this class cast it back to AbstractEventLoop.
        """
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's clock value."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class used for cancellation (asyncio's)."""
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield to the event loop only if the current task sits in a cancelled scope.

        Walks outwards from the task's innermost cancel scope, stopping at the
        first shielded scope. While the examined scope has been cancelled, this
        keeps yielding with ``sleep(0)`` without advancing past it (presumably
        relying on the pending cancellation being delivered to the task).
        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            # Task not managed by AnyIO: nothing to check
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield to the event loop once inside a shielded cancel scope."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Sleep for ``delay`` seconds using :func:`asyncio.sleep`."""
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Create a new cancel scope with the given deadline and shield flag."""
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline that applies to the current task.

        Takes the minimum deadline over the cancel scope chain up to and including
        the first shielded scope. Returns ``-inf`` if a scope on that path has
        already been cancelled, and ``inf`` if the task has no AnyIO task state.
        """
        try:
            cancel_scope = _task_states[current_task()].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group."""
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event."""
        return Event()

    @classmethod
    def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
        """Create a new lock, forwarding ``fast_acquire``."""
        return Lock(fast_acquire=fast_acquire)

    @classmethod
    def create_semaphore(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> abc.Semaphore:
        """Create a new semaphore with the given initial and maximum value."""
        return Semaphore(initial_value, max_value=max_value, fast_acquire=fast_acquire)

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Create a new capacity limiter with ``total_tokens`` tokens."""
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in a worker thread and wait for the result.

        Worker threads are reused: an idle worker is taken if available, otherwise
        a new one is started and stopped when the root task finishes.

        :param abandon_on_cancel: if ``False``, waiting for the result is shielded
            from cancellation
        :param limiter: capacity limiter to hold while the call runs (defaults to
            :meth:`current_default_thread_limiter`)

        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    # Make sure the new worker is stopped when the root task ends
                    root_task.add_done_callback(worker.stop)
                else:
                    # Reuse the worker at the right end of the idle deque
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    # (assumes the left end holds the longest-idle worker)
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # The worker runs func in a copy of this context with sniffio's
                # current library cleared, since it is not running in the loop
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                # Hand the worker the scope whose cancellation it should observe:
                # the shielding scope itself when abandoning on cancel (or when it
                # has no parent), otherwise the caller's enclosing scope
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future

    @classmethod
    def check_cancelled(cls) -> None:
        """
        Raise CancelledError if the calling worker thread's cancel scope is cancelled.

        Walks ``threadlocals.current_cancel_scope`` and its parents until the first
        shielded scope; a cancelled scope (including that shielded one) raises.
        """
        scope: CancelScope | None = threadlocals.current_cancel_scope
        while scope is not None:
            if scope.cancel_called:
                raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")

            if scope.shield:
                return

            scope = scope._parent_scope

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` on the loop identified by ``token`` and block for it.

        Meant to be called from a thread other than the event loop's.
        asyncio cancellation is converted to
        :class:`concurrent.futures.CancelledError` for the waiting thread.
        """

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            __tracebackhide__ = True
            task = cast(asyncio.Task, current_task())
            # Attach the new task to the calling thread's cancel scope so that
            # cancelling that scope reaches it
            _task_states[task] = TaskState(None, scope)
            scope._tasks.add(task)
            try:
                return await func(*args)
            except CancelledError as exc:
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(task)

        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        wrapper = task_wrapper(threadlocals.current_cancel_scope)
        # Schedule from within `context` so the task presumably inherits its
        # sniffio setting
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, wrapper, loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in the loop thread identified by ``token`` and block for it.

        The outcome is passed back through a :class:`concurrent.futures.Future`.
        """

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Non-Exception BaseExceptions (e.g. KeyboardInterrupt) are also
                # re-raised in the loop thread after being delivered to the caller
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal."""
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: StrOrBytesPath | Sequence[StrOrBytesPath],
        *,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        **kwargs: Any,
    ) -> Process:
        """
        Start a subprocess and wrap its pipes in AnyIO streams.

        A ``str``/``bytes`` command (or a single path-like, which is converted to
        one) is run through the shell; a sequence is executed directly.
        """
        await cls.checkpoint()
        if isinstance(command, PathLike):
            command = os.fspath(command)

        if isinstance(command, (str, bytes)):
            process = await asyncio.create_subprocess_shell(
                command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the process pool ``workers`` to be shut down with the root task.

        Starts a task running ``_shutdown_process_pool_on_exit`` and registers
        ``_forcibly_shutdown_process_pool_on_exit`` as a done callback of the root
        task.
        """
        # NOTE(review): the created task is not stored; asyncio only keeps weak
        # references to tasks, so this relies on something else keeping it alive
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)  # type:ignore[arg-type]
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """Open a TCP connection to ``host``:``port`` and return a socket stream."""
        transport, protocol = cast(
            tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        # Reading starts paused; presumably SocketStream resumes it on demand
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """
        Connect a non-blocking UNIX stream socket to ``path``.

        ``connect()`` is retried every time the socket becomes writable after it
        raised :exc:`BlockingIOError`. The socket is closed if connecting fails for
        any reason, including cancellation while waiting.
        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        # The cleanup handler must enclose the whole retry loop: an exception
        # raised while awaiting writability (e.g. cancellation) is raised inside the
        # ``except BlockingIOError`` handler, and a sibling ``except`` clause would
        # not catch it, leaking the socket.
        try:
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    f: asyncio.Future = loop.create_future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    break
        except BaseException:
            raw_socket.close()
            raise

        return UNIXSocketStream(raw_socket)

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a bound, listening TCP socket in a listener."""
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a bound, listening UNIX socket in a listener."""
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a UDP endpoint.

        Returns a connected socket when ``remote_address`` is given, otherwise an
        unconnected one. An exception recorded by the protocol during setup is
        raised after closing the transport.
        """
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap ``raw_socket`` as a UNIX datagram socket, connecting it if requested.

        When ``remote_path`` is truthy, ``connect()`` is retried whenever the
        socket becomes writable after a :exc:`BlockingIOError`. On any failure,
        including cancellation while waiting, ``raw_socket`` is closed.
        """
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            # As in connect_unix, the cleanup handler encloses the whole retry loop
            # so exceptions raised while awaiting writability also close the socket
            try:
                while True:
                    try:
                        raw_socket.connect(remote_path)
                    except BlockingIOError:
                        f: asyncio.Future = loop.create_future()
                        loop.add_writer(raw_socket, f.set_result, None)
                        f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                        await f
                    else:
                        break
            except BaseException:
                raw_socket.close()
                raise

            return ConnectedUNIXDatagramSocket(raw_socket)
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host``/``port`` via the running loop's ``getaddrinfo()``."""
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Reverse-resolve ``sockaddr`` via the running loop's ``getnameinfo()``."""
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` becomes readable.

        :raises BusyResourceError: if another waiter is already registered for
            reading from ``sock``
        :raises ClosedResourceError: if this waiter's registration was removed by
            someone else before the wait finished (presumably when the socket is
            being closed)
        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            # Only clean up our own registration; a missing entry means someone
            # else already removed it
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` becomes writable.

        :raises BusyResourceError: if another waiter is already registered for
            writing to ``sock``
        :raises ClosedResourceError: if this waiter's registration was removed by
            someone else before the wait finished (presumably when the socket is
            being closed)
        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # NOTE: registered by file descriptor but removed by socket object below;
        # asyncio accepts either form for add_writer()/remove_writer()
        loop.add_writer(sock.fileno(), event.set)
        try:
            await event.wait()
        finally:
            # Only clean up our own registration; a missing entry means someone
            # else already removed it
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """Return the default thread limiter, creating one with 40 tokens on first use."""
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> AbstractContextManager[AsyncIterator[Signals]]:
        """Return a context manager yielding an iterator over received ``signals``."""
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return information about the currently running task."""
        return AsyncIOTaskInfo(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> Sequence[TaskInfo]:
        """Return information about every task of the running loop that is not done."""
        return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every other task is blocked on a pending future.

        Polls: if any other task is not waiting on an unfinished future (its
        private ``_fut_waiter`` is ``None`` or done), sleep 0.1 s and rescan all
        tasks from scratch.
        """
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """Create a test runner configured with ``options``."""
        return TestRunner(**options)
2827

2828

2829
# Module-level alias exposing this module's backend implementation
# (presumably looked up by name when the asyncio backend is loaded — verify)
backend_class = AsyncIOBackend
10✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc