• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 11765112365

10 Nov 2024 12:27PM UTC coverage: 91.444% (+0.004%) from 91.44%
11765112365

Pull #822

github

web-flow
Merge 1807bede3 into bdf09a6ed
Pull Request #822: Added support for asyncio eager task factories

44 of 46 new or added lines in 2 files covered. (95.65%)

125 existing lines in 1 file now uncovered.

4852 of 5306 relevant lines covered (91.44%)

8.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.81
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import os
10✔
8
import socket
10✔
9
import sys
10✔
10
import threading
10✔
11
import weakref
10✔
12
from asyncio import (
10✔
13
    AbstractEventLoop,
14
    CancelledError,
15
    all_tasks,
16
    create_task,
17
    current_task,
18
    get_running_loop,
19
    sleep,
20
)
21
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
22
from collections import OrderedDict, deque
10✔
23
from collections.abc import (
10✔
24
    AsyncGenerator,
25
    AsyncIterator,
26
    Awaitable,
27
    Callable,
28
    Collection,
29
    Coroutine,
30
    Iterable,
31
    Sequence,
32
)
33
from concurrent.futures import Future
10✔
34
from contextlib import AbstractContextManager, suppress
10✔
35
from contextvars import Context, copy_context
10✔
36
from dataclasses import dataclass
10✔
37
from functools import partial, wraps
10✔
38
from inspect import (
10✔
39
    CORO_RUNNING,
40
    CORO_SUSPENDED,
41
    getcoroutinestate,
42
    iscoroutine,
43
)
44
from io import IOBase
10✔
45
from os import PathLike
10✔
46
from queue import Queue
10✔
47
from signal import Signals
10✔
48
from socket import AddressFamily, SocketKind
10✔
49
from threading import Thread
10✔
50
from types import TracebackType
10✔
51
from typing import (
10✔
52
    IO,
53
    Any,
54
    Optional,
55
    TypeVar,
56
    cast,
57
)
58
from weakref import WeakKeyDictionary
10✔
59

60
import sniffio
10✔
61

62
from .. import (
10✔
63
    CapacityLimiterStatistics,
64
    EventStatistics,
65
    LockStatistics,
66
    TaskInfo,
67
    abc,
68
)
69
from .._core._eventloop import claim_worker_thread, threadlocals
10✔
70
from .._core._exceptions import (
10✔
71
    BrokenResourceError,
72
    BusyResourceError,
73
    ClosedResourceError,
74
    EndOfStream,
75
    WouldBlock,
76
    iterate_exceptions,
77
)
78
from .._core._sockets import convert_ipv6_sockaddr
10✔
79
from .._core._streams import create_memory_object_stream
10✔
80
from .._core._synchronization import (
10✔
81
    CapacityLimiter as BaseCapacityLimiter,
82
)
83
from .._core._synchronization import Event as BaseEvent
10✔
84
from .._core._synchronization import Lock as BaseLock
10✔
85
from .._core._synchronization import (
10✔
86
    ResourceGuard,
87
    SemaphoreStatistics,
88
)
89
from .._core._synchronization import Semaphore as BaseSemaphore
10✔
90
from .._core._tasks import CancelScope as BaseCancelScope
10✔
91
from ..abc import (
10✔
92
    AsyncBackend,
93
    IPSockAddrType,
94
    SocketListener,
95
    UDPPacketType,
96
    UNIXDatagramPacketType,
97
)
98
from ..abc._eventloop import StrOrBytesPath
10✔
99
from ..lowlevel import RunVar
10✔
100
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
101

102
if sys.version_info >= (3, 10):
10✔
103
    from typing import ParamSpec
7✔
104
else:
105
    from typing_extensions import ParamSpec
3✔
106

107
if sys.version_info >= (3, 11):
10✔
108
    from asyncio import Runner
5✔
109
    from typing import TypeVarTuple, Unpack
5✔
110
else:
111
    import contextvars
5✔
112
    import enum
5✔
113
    import signal
5✔
114
    from asyncio import coroutines, events, exceptions, tasks
5✔
115

116
    from exceptiongroup import BaseExceptionGroup
5✔
117
    from typing_extensions import TypeVarTuple, Unpack
5✔
118

119
    class _State(enum.Enum):
5✔
120
        CREATED = "created"
5✔
121
        INITIALIZED = "initialized"
5✔
122
        CLOSED = "closed"
5✔
123

124
    class Runner:
5✔
125
        # Copied from CPython 3.11
126
        def __init__(
5✔
127
            self,
128
            *,
129
            debug: bool | None = None,
130
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
131
        ):
132
            self._state = _State.CREATED
5✔
133
            self._debug = debug
5✔
134
            self._loop_factory = loop_factory
5✔
135
            self._loop: AbstractEventLoop | None = None
5✔
136
            self._context = None
5✔
137
            self._interrupt_count = 0
5✔
138
            self._set_event_loop = False
5✔
139

140
        def __enter__(self) -> Runner:
5✔
141
            self._lazy_init()
5✔
142
            return self
5✔
143

144
        def __exit__(
5✔
145
            self,
146
            exc_type: type[BaseException],
147
            exc_val: BaseException,
148
            exc_tb: TracebackType,
149
        ) -> None:
150
            self.close()
5✔
151

152
        def close(self) -> None:
5✔
153
            """Shutdown and close event loop."""
154
            if self._state is not _State.INITIALIZED:
5✔
UNCOV
155
                return
×
156
            try:
5✔
157
                loop = self._loop
5✔
158
                _cancel_all_tasks(loop)
5✔
159
                loop.run_until_complete(loop.shutdown_asyncgens())
5✔
160
                if hasattr(loop, "shutdown_default_executor"):
5✔
161
                    loop.run_until_complete(loop.shutdown_default_executor())
5✔
162
                else:
UNCOV
163
                    loop.run_until_complete(_shutdown_default_executor(loop))
×
164
            finally:
165
                if self._set_event_loop:
5✔
166
                    events.set_event_loop(None)
5✔
167
                loop.close()
5✔
168
                self._loop = None
5✔
169
                self._state = _State.CLOSED
5✔
170

171
        def get_loop(self) -> AbstractEventLoop:
5✔
172
            """Return embedded event loop."""
173
            self._lazy_init()
5✔
174
            return self._loop
5✔
175

176
        def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
5✔
177
            """Run a coroutine inside the embedded event loop."""
178
            if not coroutines.iscoroutine(coro):
5✔
UNCOV
179
                raise ValueError(f"a coroutine was expected, got {coro!r}")
×
180

181
            if events._get_running_loop() is not None:
5✔
182
                # fail fast with short traceback
UNCOV
183
                raise RuntimeError(
×
184
                    "Runner.run() cannot be called from a running event loop"
185
                )
186

187
            self._lazy_init()
5✔
188

189
            if context is None:
5✔
190
                context = self._context
5✔
191
            task = context.run(self._loop.create_task, coro)
5✔
192

193
            if (
5✔
194
                threading.current_thread() is threading.main_thread()
195
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
196
            ):
197
                sigint_handler = partial(self._on_sigint, main_task=task)
5✔
198
                try:
5✔
199
                    signal.signal(signal.SIGINT, sigint_handler)
5✔
UNCOV
200
                except ValueError:
×
201
                    # `signal.signal` may throw if `threading.main_thread` does
202
                    # not support signals (e.g. embedded interpreter with signals
203
                    # not registered - see gh-91880)
UNCOV
204
                    sigint_handler = None
×
205
            else:
206
                sigint_handler = None
5✔
207

208
            self._interrupt_count = 0
5✔
209
            try:
5✔
210
                return self._loop.run_until_complete(task)
5✔
211
            except exceptions.CancelledError:
5✔
UNCOV
212
                if self._interrupt_count > 0:
×
UNCOV
213
                    uncancel = getattr(task, "uncancel", None)
×
214
                    if uncancel is not None and uncancel() == 0:
×
215
                        raise KeyboardInterrupt()
×
216
                raise  # CancelledError
×
217
            finally:
218
                if (
5✔
219
                    sigint_handler is not None
220
                    and signal.getsignal(signal.SIGINT) is sigint_handler
221
                ):
222
                    signal.signal(signal.SIGINT, signal.default_int_handler)
5✔
223

224
        def _lazy_init(self) -> None:
6✔
225
            if self._state is _State.CLOSED:
5✔
UNCOV
226
                raise RuntimeError("Runner is closed")
×
227
            if self._state is _State.INITIALIZED:
5✔
228
                return
5✔
229
            if self._loop_factory is None:
5✔
230
                self._loop = events.new_event_loop()
5✔
231
                if not self._set_event_loop:
5✔
232
                    # Call set_event_loop only once to avoid calling
233
                    # attach_loop multiple times on child watchers
234
                    events.set_event_loop(self._loop)
6✔
235
                    self._set_event_loop = True
6✔
236
            else:
237
                self._loop = self._loop_factory()
4✔
238
            if self._debug is not None:
6✔
239
                self._loop.set_debug(self._debug)
6✔
240
            self._context = contextvars.copy_context()
5✔
241
            self._state = _State.INITIALIZED
5✔
242

243
        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
5✔
UNCOV
244
            self._interrupt_count += 1
×
245
            if self._interrupt_count == 1 and not main_task.done():
1✔
246
                main_task.cancel()
×
247
                # wakeup loop if it is blocked by select() with long timeout
248
                self._loop.call_soon_threadsafe(lambda: None)
×
UNCOV
249
                return
×
250
            raise KeyboardInterrupt()
×
251

252
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
5✔
253
        to_cancel = tasks.all_tasks(loop)
5✔
254
        if not to_cancel:
5✔
255
            return
5✔
256

257
        for task in to_cancel:
5✔
258
            task.cancel()
5✔
259

260
        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
5✔
261

262
        for task in to_cancel:
5✔
263
            if task.cancelled():
5✔
264
                continue
5✔
265
            if task.exception() is not None:
4✔
UNCOV
266
                loop.call_exception_handler(
×
267
                    {
268
                        "message": "unhandled exception during asyncio.run() shutdown",
269
                        "exception": task.exception(),
270
                        "task": task,
271
                    }
272
                )
273

274
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
5✔
275
        """Schedule the shutdown of the default executor."""
276

UNCOV
277
        def _do_shutdown(future: asyncio.futures.Future) -> None:
×
UNCOV
278
            try:
×
279
                loop._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
×
280
                loop.call_soon_threadsafe(future.set_result, None)
×
281
            except Exception as ex:
×
282
                loop.call_soon_threadsafe(future.set_exception, ex)
×
283

284
        loop._executor_shutdown_called = True
×
UNCOV
285
        if loop._default_executor is None:
×
286
            return
×
287
        future = loop.create_future()
×
288
        thread = threading.Thread(target=_do_shutdown, args=(future,))
×
289
        thread.start()
×
290
        try:
×
291
            await future
×
292
        finally:
293
            thread.join()
×
294

295

296
T_Retval = TypeVar("T_Retval")
10✔
297
T_contra = TypeVar("T_contra", contravariant=True)
10✔
298
PosArgsT = TypeVarTuple("PosArgsT")
10✔
299
P = ParamSpec("P")
10✔
300

301
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
10✔
302

303

304
def find_root_task() -> asyncio.Task:
10✔
305
    root_task = _root_task.get(None)
10✔
306
    if root_task is not None and not root_task.done():
10✔
307
        return root_task
10✔
308

309
    # Look for a task that has been started via run_until_complete()
310
    for task in all_tasks():
10✔
311
        if task._callbacks and not task.done():
10✔
312
            callbacks = [cb for cb, context in task._callbacks]
10✔
313
            for cb in callbacks:
10✔
314
                if (
10✔
315
                    cb is _run_until_complete_cb
316
                    or getattr(cb, "__module__", None) == "uvloop.loop"
317
                ):
318
                    _root_task.set(task)
10✔
319
                    return task
10✔
320

321
    # Look up the topmost task in the AnyIO task tree, if possible
322
    task = cast(asyncio.Task, current_task())
9✔
323
    state = _task_states.get(task)
9✔
324
    if state:
9✔
325
        cancel_scope = state.cancel_scope
9✔
326
        while cancel_scope and cancel_scope._parent_scope is not None:
9✔
UNCOV
327
            cancel_scope = cancel_scope._parent_scope
×
328

329
        if cancel_scope is not None:
9✔
330
            return cast(asyncio.Task, cancel_scope._host_task)
9✔
331

UNCOV
332
    return task
×
333

334

335
def get_callable_name(func: Callable) -> str:
10✔
336
    module = getattr(func, "__module__", None)
10✔
337
    qualname = getattr(func, "__qualname__", None)
10✔
338
    return ".".join([x for x in (module, qualname) if x])
10✔
339

340

341
#
342
# Event loop
343
#
344

345
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
10✔
346

347

348
def _task_started(task: asyncio.Task) -> bool:
10✔
349
    """Return ``True`` if the task has been started and has not finished."""
350
    # The task coro should never be None here, as we never add finished tasks to the
351
    # task list
352
    coro = task.get_coro()
10✔
353
    assert coro is not None
10✔
354
    try:
10✔
355
        return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
10✔
NEW
356
    except AttributeError:
×
357
        # task coro is async_generator_asend https://bugs.python.org/issue37771
NEW
358
        raise Exception(f"Cannot determine if task {task} has started or not") from None
×
359

360

361
#
362
# Timeouts and cancellation
363
#
364

365

366
def is_anyio_cancellation(exc: CancelledError) -> bool:
10✔
367
    return (
10✔
368
        bool(exc.args)
369
        and isinstance(exc.args[0], str)
370
        and exc.args[0].startswith("Cancelled by cancel scope ")
371
    )
372

373

374
class CancelScope(BaseCancelScope):
10✔
375
    def __new__(
10✔
376
        cls, *, deadline: float = math.inf, shield: bool = False
377
    ) -> CancelScope:
378
        return object.__new__(cls)
10✔
379

380
    def __init__(self, deadline: float = math.inf, shield: bool = False):
10✔
381
        self._deadline = deadline
10✔
382
        self._shield = shield
10✔
383
        self._parent_scope: CancelScope | None = None
10✔
384
        self._child_scopes: set[CancelScope] = set()
10✔
385
        self._cancel_called = False
10✔
386
        self._cancelled_caught = False
10✔
387
        self._active = False
10✔
388
        self._timeout_handle: asyncio.TimerHandle | None = None
10✔
389
        self._cancel_handle: asyncio.Handle | None = None
10✔
390
        self._tasks: set[asyncio.Task] = set()
10✔
391
        self._host_task: asyncio.Task | None = None
10✔
392
        self._cancel_calls: int = 0
10✔
393
        self._cancelling: int | None = None
10✔
394

395
    def __enter__(self) -> CancelScope:
10✔
396
        if self._active:
10✔
UNCOV
397
            raise RuntimeError(
×
398
                "Each CancelScope may only be used for a single 'with' block"
399
            )
400

401
        self._host_task = host_task = cast(asyncio.Task, current_task())
10✔
402
        self._tasks.add(host_task)
10✔
403
        try:
10✔
404
            task_state = _task_states[host_task]
10✔
405
        except KeyError:
10✔
406
            task_state = TaskState(None, self)
10✔
407
            _task_states[host_task] = task_state
10✔
408
        else:
409
            self._parent_scope = task_state.cancel_scope
10✔
410
            task_state.cancel_scope = self
10✔
411
            if self._parent_scope is not None:
10✔
412
                self._parent_scope._child_scopes.add(self)
10✔
413
                self._parent_scope._tasks.remove(host_task)
10✔
414

415
        self._timeout()
10✔
416
        self._active = True
10✔
417
        if sys.version_info >= (3, 11):
10✔
418
            self._cancelling = self._host_task.cancelling()
5✔
419

420
        # Start cancelling the host task if the scope was cancelled before entering
421
        if self._cancel_called:
10✔
422
            self._deliver_cancellation(self)
10✔
423

424
        return self
10✔
425

426
    def __exit__(
10✔
427
        self,
428
        exc_type: type[BaseException] | None,
429
        exc_val: BaseException | None,
430
        exc_tb: TracebackType | None,
431
    ) -> bool | None:
432
        del exc_tb
10✔
433

434
        if not self._active:
10✔
435
            raise RuntimeError("This cancel scope is not active")
9✔
436
        if current_task() is not self._host_task:
10✔
437
            raise RuntimeError(
9✔
438
                "Attempted to exit cancel scope in a different task than it was "
439
                "entered in"
440
            )
441

442
        assert self._host_task is not None
10✔
443
        host_task_state = _task_states.get(self._host_task)
10✔
444
        if host_task_state is None or host_task_state.cancel_scope is not self:
10✔
445
            raise RuntimeError(
9✔
446
                "Attempted to exit a cancel scope that isn't the current tasks's "
447
                "current cancel scope"
448
            )
449

450
        try:
10✔
451
            self._active = False
10✔
452
            if self._timeout_handle:
10✔
453
                self._timeout_handle.cancel()
10✔
454
                self._timeout_handle = None
10✔
455

456
            self._tasks.remove(self._host_task)
10✔
457
            if self._parent_scope is not None:
10✔
458
                self._parent_scope._child_scopes.remove(self)
10✔
459
                self._parent_scope._tasks.add(self._host_task)
10✔
460

461
            host_task_state.cancel_scope = self._parent_scope
10✔
462

463
            # Undo all cancellations done by this scope
464
            if self._cancelling is not None:
10✔
465
                while self._cancel_calls:
5✔
466
                    self._cancel_calls -= 1
5✔
467
                    if self._host_task.uncancel() <= self._cancelling:
5✔
468
                        break
5✔
469

470
            # We swallow the exception only if it was an AnyIO CancelledError, either
471
            # directly as exc_val or inside an exception group and there are no cancelled
472
            # parent cancel scopes visible to us here
473
            not_swallowed_exceptions = 0
10✔
474
            swallow_exception = False
10✔
475
            if exc_val is not None:
10✔
476
                for exc in iterate_exceptions(exc_val):
10✔
477
                    if self._cancel_called and isinstance(exc, CancelledError):
10✔
478
                        if not (swallow_exception := self._uncancel(exc)):
10✔
479
                            not_swallowed_exceptions += 1
10✔
480
                    else:
481
                        not_swallowed_exceptions += 1
10✔
482

483
            # Restart the cancellation effort in the closest visible, cancelled parent
484
            # scope if necessary
485
            self._restart_cancellation_in_parent()
10✔
486
            return swallow_exception and not not_swallowed_exceptions
10✔
487
        finally:
488
            self._host_task = None
10✔
489
            del exc_val
10✔
490

491
    @property
10✔
492
    def _effectively_cancelled(self) -> bool:
10✔
493
        cancel_scope: CancelScope | None = self
10✔
494
        while cancel_scope is not None:
10✔
495
            if cancel_scope._cancel_called:
10✔
496
                return True
10✔
497

498
            if cancel_scope.shield:
10✔
499
                return False
9✔
500

501
            cancel_scope = cancel_scope._parent_scope
10✔
502

503
        return False
10✔
504

505
    @property
10✔
506
    def _parent_cancellation_is_visible_to_us(self) -> bool:
10✔
507
        return (
10✔
508
            self._parent_scope is not None
509
            and not self.shield
510
            and self._parent_scope._effectively_cancelled
511
        )
512

513
    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
10✔
514
        if self._host_task is None:
10✔
UNCOV
515
            self._cancel_calls = 0
×
UNCOV
516
            return True
×
517

518
        while True:
7✔
519
            if is_anyio_cancellation(cancelled_exc):
10✔
520
                # Only swallow the cancellation exception if it's an AnyIO cancel
521
                # exception and there are no other cancel scopes down the line pending
522
                # cancellation
523
                self._cancelled_caught = (
10✔
524
                    self._effectively_cancelled
525
                    and not self._parent_cancellation_is_visible_to_us
526
                )
527
                return self._cancelled_caught
10✔
528

529
            # Sometimes third party frameworks catch a CancelledError and raise a new
530
            # one, so as a workaround we have to look at the previous ones in
531
            # __context__ too for a matching cancel message
532
            if isinstance(cancelled_exc.__context__, CancelledError):
9✔
533
                cancelled_exc = cancelled_exc.__context__
5✔
534
                continue
5✔
535

536
            return False
9✔
537

538
    def _timeout(self) -> None:
10✔
539
        if self._deadline != math.inf:
10✔
540
            loop = get_running_loop()
10✔
541
            if loop.time() >= self._deadline:
10✔
542
                self.cancel()
10✔
543
            else:
544
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)
10✔
545

546
    def _deliver_cancellation(self, origin: CancelScope) -> bool:
10✔
547
        """
548
        Deliver cancellation to directly contained tasks and nested cancel scopes.
549

550
        Schedule another run at the end if we still have tasks eligible for
551
        cancellation.
552

553
        :param origin: the cancel scope that originated the cancellation
554
        :return: ``True`` if the delivery needs to be retried on the next cycle
555

556
        """
557
        should_retry = False
10✔
558
        current = current_task()
10✔
559
        for task in self._tasks:
10✔
560
            should_retry = True
10✔
561
            if task._must_cancel:  # type: ignore[attr-defined]
10✔
562
                continue
9✔
563

564
            # The task is eligible for cancellation if it has started
565
            if task is not current and (task is self._host_task or _task_started(task)):
10✔
566
                waiter = task._fut_waiter  # type: ignore[attr-defined]
10✔
567
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
10✔
568
                    task.cancel(f"Cancelled by cancel scope {id(origin):x}")
10✔
569
                    if task is origin._host_task:
10✔
570
                        origin._cancel_calls += 1
10✔
571

572
        # Deliver cancellation to child scopes that aren't shielded or running their own
573
        # cancellation callbacks
574
        for scope in self._child_scopes:
10✔
575
            if not scope._shield and not scope.cancel_called:
10✔
576
                should_retry = scope._deliver_cancellation(origin) or should_retry
10✔
577

578
        # Schedule another callback if there are still tasks left
579
        if origin is self:
10✔
580
            if should_retry:
10✔
581
                self._cancel_handle = get_running_loop().call_soon(
10✔
582
                    self._deliver_cancellation, origin
583
                )
584
            else:
585
                self._cancel_handle = None
10✔
586

587
        return should_retry
10✔
588

589
    def _restart_cancellation_in_parent(self) -> None:
10✔
590
        """
591
        Restart the cancellation effort in the closest directly cancelled parent scope.
592

593
        """
594
        scope = self._parent_scope
10✔
595
        while scope is not None:
10✔
596
            if scope._cancel_called:
10✔
597
                if scope._cancel_handle is None:
10✔
598
                    scope._deliver_cancellation(scope)
10✔
599

600
                break
10✔
601

602
            # No point in looking beyond any shielded scope
603
            if scope._shield:
10✔
604
                break
9✔
605

606
            scope = scope._parent_scope
10✔
607

608
    def cancel(self) -> None:
10✔
609
        if not self._cancel_called:
10✔
610
            if self._timeout_handle:
10✔
611
                self._timeout_handle.cancel()
10✔
612
                self._timeout_handle = None
10✔
613

614
            self._cancel_called = True
10✔
615
            if self._host_task is not None:
10✔
616
                self._deliver_cancellation(self)
10✔
617

618
    @property
10✔
619
    def deadline(self) -> float:
10✔
620
        return self._deadline
9✔
621

622
    @deadline.setter
10✔
623
    def deadline(self, value: float) -> None:
10✔
624
        self._deadline = float(value)
9✔
625
        if self._timeout_handle is not None:
9✔
626
            self._timeout_handle.cancel()
9✔
627
            self._timeout_handle = None
9✔
628

629
        if self._active and not self._cancel_called:
9✔
630
            self._timeout()
9✔
631

632
    @property
10✔
633
    def cancel_called(self) -> bool:
10✔
634
        return self._cancel_called
10✔
635

636
    @property
10✔
637
    def cancelled_caught(self) -> bool:
10✔
638
        return self._cancelled_caught
10✔
639

640
    @property
10✔
641
    def shield(self) -> bool:
10✔
642
        return self._shield
10✔
643

644
    @shield.setter
10✔
645
    def shield(self, value: bool) -> None:
10✔
646
        if self._shield != value:
10✔
647
            self._shield = value
10✔
648
            if not value:
10✔
649
                self._restart_cancellation_in_parent()
9✔
650

651

652
#
653
# Task states
654
#
655

656

657
class TaskState:
10✔
658
    """
659
    Encapsulates auxiliary task information that cannot be added to the Task instance
660
    itself because there are no guarantees about its implementation.
661
    """
662

663
    __slots__ = "parent_id", "cancel_scope", "__weakref__"
10✔
664

665
    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
10✔
666
        self.parent_id = parent_id
10✔
667
        self.cancel_scope = cancel_scope
10✔
668

669

670
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
10✔
671

672

673
#
674
# Task groups
675
#
676

677

678
class _AsyncioTaskStatus(abc.TaskStatus):
10✔
679
    def __init__(self, future: asyncio.Future, parent_id: int):
10✔
680
        self._future = future
10✔
681
        self._parent_id = parent_id
10✔
682

683
    def started(self, value: T_contra | None = None) -> None:
10✔
684
        try:
10✔
685
            self._future.set_result(value)
10✔
686
        except asyncio.InvalidStateError:
9✔
687
            if not self._future.cancelled():
9✔
688
                raise RuntimeError(
9✔
689
                    "called 'started' twice on the same task status"
690
                ) from None
691

692
        task = cast(asyncio.Task, current_task())
10✔
693
        _task_states[task].parent_id = self._parent_id
10✔
694

695

696
async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None:
10✔
697
    tasks = set(tasks)
10✔
698
    waiter = get_running_loop().create_future()
10✔
699

700
    def on_completion(task: asyncio.Task[object]) -> None:
10✔
701
        tasks.discard(task)
10✔
702
        if not tasks and not waiter.done():
10✔
703
            waiter.set_result(None)
10✔
704

705
    for task in tasks:
10✔
706
        task.add_done_callback(on_completion)
10✔
707
        del task
10✔
708

709
    try:
10✔
710
        await waiter
10✔
711
    finally:
712
        while tasks:
10✔
713
            tasks.pop().remove_done_callback(on_completion)
10✔
714

715

716
class TaskGroup(abc.TaskGroup):
    """
    Task group implementation for the asyncio backend.

    Every spawned task is registered both with this group and with the group's cancel
    scope. Exceptions raised by child tasks are collected and raised together as an
    exception group when the ``async with`` block is exited.

    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        self._exceptions: list[BaseException] = []
        self._tasks: set[asyncio.Task] = set()

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all child tasks to finish and report their errors.

        If the ``async with`` body raised, the cancel scope is cancelled first. Once
        there are no child tasks left, the collected exceptions (if any) are raised as
        a :exc:`BaseExceptionGroup`; otherwise a pending ``exc_val`` is re-raised.
        Either way the outcome is passed through the cancel scope's ``__exit__()``.

        :return: a truthy value if the exception was swallowed by the cancel scope

        """
        try:
            if exc_val is not None:
                self.cancel_scope.cancel()
                if not isinstance(exc_val, CancelledError):
                    self._exceptions.append(exc_val)

            try:
                if self._tasks:
                    with CancelScope() as wait_scope:
                        while self._tasks:
                            try:
                                await _wait(self._tasks)
                            except CancelledError as exc:
                                # Shield the scope against further cancellation attempts,
                                # as they're not productive (#695)
                                wait_scope.shield = True
                                self.cancel_scope.cancel()

                                # Set exc_val from the cancellation exception if it was
                                # previously unset. However, we should not replace a native
                                # cancellation exception with one raised by a cancel scope.
                                if exc_val is None or (
                                    isinstance(exc_val, CancelledError)
                                    and not is_anyio_cancellation(exc)
                                ):
                                    exc_val = exc
                else:
                    # If there are no child tasks to wait on, run at least one checkpoint
                    # anyway
                    await AsyncIOBackend.cancel_shielded_checkpoint()

                self._active = False
                if self._exceptions:
                    raise BaseExceptionGroup(
                        "unhandled errors in a TaskGroup", self._exceptions
                    )
                elif exc_val:
                    raise exc_val
            except BaseException as exc:
                if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
                    return True

                raise

            return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # NOTE(review): presumably done to avoid keeping exceptions (and their
            # tracebacks) alive through this frame and the task group — confirm
            del exc_val, exc_tb, self._exceptions

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` and register it with this group.

        :param func: callable that returns a coroutine object
        :param args: positional arguments to call ``func`` with
        :param name: name of the task; if ``None``, a name is derived from ``func``
        :param task_status_future: if given, ``func`` is additionally passed a
            ``task_status`` keyword argument that completes this future
        :return: the newly created task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` did not return a coroutine object

        """

        def task_done(_task: asyncio.Task) -> None:
            # Unregister the finished task from its cancel scope, this task group and
            # the global task state mapping
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Walk back to the first cancellation exception in the context chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # The future can only be in the cancelled state if the host task was
                # cancelled, so return immediately instead of adding one more
                # CancelledError to the exceptions list
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    # Either a plain child task, or one that already reported its
                    # start value: the error belongs to the task group
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._effectively_cancelled:
                        self.cancel_scope.cancel()
                else:
                    # Failed before reporting a start value: hand the error over to
                    # whoever awaits the future
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs: dict[str, Any] = {}
        if task_status_future is not None:
            # Until the child reports that it has started, the current task counts as
            # its parent; afterwards the task status object switches the parent to the
            # cancel scope's host task
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            # Not every callable has __qualname__ (e.g. functools.partial objects), so
            # fall back to its repr instead of failing with an AttributeError here
            qualname = getattr(func, "__qualname__", None)
            if qualname is None:
                described = repr(func)
            else:
                prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
                described = f"{prefix}{qualname}()"

            raise TypeError(
                f"Expected {described} to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        task.add_done_callback(task_done)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        """
        Start a new child task without waiting for it.

        :param func: callable that returns a coroutine object
        :param args: positional arguments to call ``func`` with
        :param name: name of the task

        """
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Start a new child task and wait until it reports that it has started.

        :param func: callable that returns a coroutine object and accepts a
            ``task_status`` keyword argument
        :param args: positional arguments to call ``func`` with
        :param name: name of the task
        :return: the value the child passed to ``task_status.started()``

        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
886

887

888
#
889
# Threads
890
#
891

892
# A pair of (return value, exception), either of which may be None
_Retval_Queue_Type = tuple[Optional[T_Retval], Optional[BaseException]]
893

894

895
class WorkerThread(Thread):
    """
    Thread that runs synchronous callables queued to it by an event loop.

    Work items arrive through :attr:`queue`; the outcome of each one is delivered to
    its future via a callback scheduled on the event loop.

    :param root_task: task whose event loop receives the results
    :param workers: set of worker threads; this thread is discarded from it when
        stopped
    :param idle_workers: deque of idle worker threads; this thread is appended to it
        after each reported result (unless it is stopping)

    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Each item is (context, func, args, future, cancel_scope); None tells the
        # thread to shut down
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Mark this worker as idle and complete the given future.

        Scheduled on the event loop by :meth:`run` via ``call_soon_threadsafe()``.

        :param future: the future to complete; left alone if it was cancelled
        :param result: the return value of the callable
        :param exc: the exception raised by the callable, if any

        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                if isinstance(exc, StopIteration):
                    # StopIteration is wrapped in a RuntimeError before it is set on
                    # the future
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """Process queued work items until the shutdown command is received."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future, cancel_scope = item
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    # The outcome is dropped if the event loop is already closed
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Tell the thread to shut down and remove it from the worker collections.

        :param f: ignored; NOTE(review): presumably present so that this method can be
            used as a task done callback — confirm against callers

        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        # The thread is not in the idle deque if it is currently running a job
        with suppress(ValueError):
            self.idle_workers.remove(self)
968

969

970
# Worker threads that are currently idle
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# All worker threads, idle or busy
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
974

975

976
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop that is running when it is created."""

    def __new__(cls) -> BlockingPortal:
        # NOTE(review): presumably bypasses a custom __new__() on the base class —
        # confirm against abc.BlockingPortal
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Start a task in the portal's task group from a thread other than the event
        loop's.

        The task runs ``self._call_func(func, args, kwargs, future)``.

        :param func: the callable to run
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :param name: name of the task
        :param future: the future passed on to ``_call_func()``

        """
        AsyncIOBackend.run_sync_from_thread(
            partial(self._task_group.start_soon, name=name),
            (self._call_func, func, args, kwargs, future),
            self._loop,
        )
997

998

999
#
1000
# Subprocesses
1001
#
1002

1003

1004
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Exposes an :class:`asyncio.StreamReader` as a byte receive stream."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to ``max_bytes`` bytes from the stream.

        :raises EndOfStream: if the read returned no data

        """
        data = await self._stream.read(max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def aclose(self) -> None:
        # Set ClosedResourceError on the underlying reader rather than closing it
        self._stream.set_exception(ClosedResourceError())
        await AsyncIOBackend.checkpoint()
1018

1019

1020
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Exposes an :class:`asyncio.StreamWriter` as a byte send stream."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write ``item`` to the stream and wait for the writer to drain."""
        self._stream.write(item)
        await self._stream.drain()

    async def aclose(self) -> None:
        self._stream.close()
        await AsyncIOBackend.checkpoint()
1031

1032

1033
@dataclass(eq=False)
class Process(abc.Process):
    """Wraps an :class:`asyncio.subprocess.Process` and its standard streams."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """
        Close the standard streams and wait for the process to exit.

        The streams are closed under a shielded cancel scope. The shield is lifted
        while waiting for the process; if that wait is interrupted by any exception,
        the process is killed and waited on (shielded again) before the exception is
        re-raised.

        """
        with CancelScope(shield=True) as scope:
            if self._stdin:
                await self._stdin.aclose()
            if self._stdout:
                await self._stdout.aclose()
            if self._stderr:
                await self._stderr.aclose()

            scope.shield = False
            try:
                await self.wait()
            except BaseException:
                scope.shield = True
                self.kill()
                await self.wait()
                raise

    async def wait(self) -> int:
        """Wait for the process to exit and return its return code."""
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
1089

1090

1091
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shuts down worker processes belonging to this event loop.

    :param workers: the worker processes to shut down
    :param _task: ignored

    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # NOTE(review): this skips processes that are still running (returncode is
        # None) and only acts on ones that already have a return code, which looks
        # inverted for a "forcibly shut down" routine — confirm the intended condition
        if process.returncode is None:
            continue

        process._stdin._stream._transport.close()  # type: ignore[union-attr]
        process._stdout._stream._transport.close()  # type: ignore[union-attr]
        process._stderr._stream._transport.close()  # type: ignore[union-attr]
        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
1114

1115

1116
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    Sleeps until cancelled; on cancellation, kills every worker that has no return
    code yet and then closes all workers.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    :param workers: the worker processes to shut down

    """
    process: abc.Process
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        # Kill all running workers first so that none of them keeps running while
        # another one is being closed
        for process in workers:
            if process.returncode is None:
                process.kill()

        for process in workers:
            await process.aclose()
1134

1135

1136
#
1137
# Sockets and networking
1138
#
1139

1140

1141
class StreamProtocol(asyncio.Protocol):
    """
    Protocol that buffers received data and tracks the writability of a transport.

    ``read_event`` is set whenever there is something for a reader to act on (data,
    end of stream or a lost connection). ``write_event`` is cleared while the
    transport has asked the protocol to pause writing.

    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None
    is_at_eof: bool = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            self.exception = BrokenResourceError()
            self.exception.__cause__ = exc

        # Wake up both readers and writers so they can notice the lost connection
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        # ProactorEventloop sometimes sends bytearray instead of bytes
        self.read_queue.append(bytes(data))
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.is_at_eof = True
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        # Clear the existing event instead of replacing it, so that a task already
        # waiting on it is still woken up by resume_writing() or connection_lost()
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1178

1179

1180
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Protocol that buffers received datagrams and tracks the writability of a
    transport.

    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Being a bounded deque, the queue drops its oldest entry when a new datagram
        # is appended while it is full
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up both readers and writers; the exception itself is not stored
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        addr = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, addr))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
1209

1210

1211
class SocketStream(abc.SocketStream):
    """
    Socket stream built on an asyncio transport and a :class:`StreamProtocol`.

    :param transport: the transport to send data through
    :param protocol: the protocol instance that buffers the received data

    """

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of buffered data, at most ``max_bytes`` long.

        :raises ClosedResourceError: if the buffer is empty and this stream was closed
        :raises EndOfStream: if the buffer is empty and no error was recorded by the
            protocol

        """
        with self._receive_guard:
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
                and not self._protocol.is_at_eof
            ):
                # Nothing to act on yet: let the transport read until the protocol
                # signals a read event, then pause reading again
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait until writing is not paused.

        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: if the write failed because the transport is
            closing

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        # An OSError from write_eof() is deliberately ignored (best effort)
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the transport, unless it is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            # Ask for a normal close, yield to the event loop once, then abort the
            # transport
            self._transport.close()
            await sleep(0)
            self._transport.abort()
1294

1295

1296
class _RawSocketMixin:
    """
    Mixin for stream classes that operate directly on a raw socket.

    ``_receive_future`` and ``_send_future`` only exist as instance attributes while
    a wait is in progress; otherwise the class-level default of ``None`` applies.

    :param raw_socket: the socket to operate on

    """

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes when the socket becomes readable."""

        def callback(f: object) -> None:
            # Deleting the instance attribute restores the class-level None default
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes when the socket becomes writable."""

        def callback(f: object) -> None:
            # Deleting the instance attribute restores the class-level None default
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def aclose(self) -> None:
        """Close the socket and wake up any pending receive or send operation."""
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            # A future may already be completed or cancelled while its cleanup
            # callback has not run yet; calling set_result() on it would raise
            # InvalidStateError, so only pending futures are resolved
            for future in (self._receive_future, self._send_future):
                if future is not None and not future.done():
                    future.set_result(None)
1340

1341

1342
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """
    Stream that operates directly on a raw socket and can pass file descriptors.

    In every method an :exc:`OSError` from the socket is reported as
    ``ClosedResourceError`` if this stream is being closed, otherwise as
    ``BrokenResourceError``.

    """

    async def send_eof(self) -> None:
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes from the socket.

        :raises EndOfStream: if the socket returned no data

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    # Nothing to read yet; retry once the socket becomes readable
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """Send all of ``item``, waiting for writability whenever necessary."""
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            view = memoryview(item)
            while view:
                try:
                    bytes_sent = self._raw_socket.send(view)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    # Keep only the part that has not been sent yet
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message along with file descriptors sent as ancillary data.

        :param msglen: maximum length of the message, in bytes
        :param maxfds: maximum number of file descriptors to receive
        :return: a tuple of (message, file descriptors)
        :raises ValueError: if ``msglen`` is not a non-negative integer or ``maxfds``
            is not a positive integer
        :raises EndOfStream: if neither a message nor ancillary data was received
        :raises RuntimeError: if ancillary data of an unexpected kind was received

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore any trailing bytes that do not make up a whole array item
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message along with file descriptors as ancillary data.

        :param message: a non-empty bytestring
        :param fds: a non-empty collection of file descriptors, given either as
            integers or as file objects
        :raises ValueError: if ``message`` or ``fds`` is empty

        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        # NOTE(review): items that are neither int nor IOBase are silently skipped —
        # confirm whether that is intended
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
1456

1457

1458
class TCPSocketListener(abc.SocketListener):
    """
    Listener that accepts connections from a raw socket through the event loop.

    :param raw_socket: the listening socket

    """

    # Cancel scope of the accept() call in progress, if any
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept an incoming connection.

        :return: a stream wrapping the accepted connection
        :raises ClosedResourceError: if the listener is closed, either before or
            during the call

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # A cancellation caused by aclose() is reported as the listener
                    # having been closed
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling any accept() call in progress."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Yield to the event loop once before closing the socket
            await sleep(0)

        self._raw_socket.close()
1516

1517

1518
class UNIXSocketListener(abc.SocketListener):
    """Socket listener that hands out ``UNIXSocketStream`` objects."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def accept(self) -> abc.SocketStream:
        """
        Accept the next incoming connection.

        The raw socket's ``accept()`` is retried in a loop; whenever it would block,
        the task waits until the event loop reports the socket as readable.

        :return: a ``UNIXSocketStream`` wrapping the (non-blocking) client socket
        :raises ClosedResourceError: if an ``OSError`` occurs after ``aclose()``
        :raises BrokenResourceError: if an ``OSError`` occurs otherwise

        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    connection, _ = self.__raw_socket.accept()
                    connection.setblocking(False)
                    return UNIXSocketStream(connection)
                except BlockingIOError:
                    # Nothing to accept yet: wait for readability, and unregister
                    # the reader as soon as the future completes (or is cancelled)
                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(
                        self.__raw_socket, readable.set_result, None
                    )
                    readable.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    await readable
                except OSError as exc:
                    if not self._closed:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def aclose(self) -> None:
        """Mark the listener as closed and close the raw socket."""
        self._closed = True
        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        """The underlying listening socket."""
        return self.__raw_socket
7✔
1553

1554

1555
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set only by aclose(); distinguishes "closed by us" from "broken"
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object reported by the transport's extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """
        Close the socket.

        The closed flag is recorded unconditionally, so that operations attempted
        after an explicit close raise ``ClosedResourceError`` even if the transport
        had already started closing on its own before this call.

        """
        self._closed = True
        if not self._transport.is_closing():
            self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Receive the next queued datagram.

        :return: the first item of the protocol's read queue
        :raises ClosedResourceError: if the queue is empty and this socket has been
            closed via ``aclose()``
        :raises BrokenResourceError: if the queue is empty for any other reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                return self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a datagram.

        :param item: packet whose elements are unpacked into the transport's
            ``sendto()`` call
        :raises ClosedResourceError: if this socket has been closed via ``aclose()``
        :raises BrokenResourceError: if the transport is closing for another reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            elif self._transport.is_closing():
                raise BrokenResourceError
            else:
                self._transport.sendto(*item)
9✔
1601

1602

1603
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set only by aclose(); distinguishes "closed by us" from "broken"
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object reported by the transport's extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """
        Close the socket.

        The closed flag is recorded unconditionally, so that operations attempted
        after an explicit close raise ``ClosedResourceError`` even if the transport
        had already started closing on its own before this call.

        """
        self._closed = True
        if not self._transport.is_closing():
            self._transport.close()

    async def receive(self) -> bytes:
        """
        Receive the next queued datagram.

        :return: the first element of the first item in the protocol's read queue
        :raises ClosedResourceError: if the queue is empty and this socket has been
            closed via ``aclose()``
        :raises BrokenResourceError: if the queue is empty for any other reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # If the buffer is empty, ask for more data
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                packet = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

            # Only the first element of the queued item is handed to the caller
            return packet[0]

    async def send(self, item: bytes) -> None:
        """
        Send a datagram through the transport.

        :param item: the payload passed to the transport's ``sendto()``
        :raises ClosedResourceError: if this socket has been closed via ``aclose()``
        :raises BrokenResourceError: if the transport is closing for another reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            elif self._transport.is_closing():
                raise BrokenResourceError
            else:
                self._transport.sendto(item)
9✔
1651

1652

1653
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket operating directly on the raw socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Receive one datagram via ``recvfrom()``.

        The call is retried after ``self._wait_until_readable()`` whenever the raw
        socket would block.

        :return: the value returned by the raw socket's ``recvfrom(65536)``
        :raises ClosedResourceError: on ``OSError`` while ``self._closing`` is set
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send one datagram via ``sendto()``.

        The call is retried after ``self._wait_until_writable()`` whenever the raw
        socket would block.

        :param item: packet whose elements are unpacked into ``sendto()``
        :raises ClosedResourceError: on ``OSError`` while ``self._closing`` is set
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
7✔
1687

1688

1689
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket operating directly on the raw socket."""

    async def receive(self) -> bytes:
        """
        Receive one datagram via ``recv()``.

        The call is retried after ``self._wait_until_readable()`` whenever the raw
        socket would block.

        :return: the bytes returned by the raw socket's ``recv(65536)``
        :raises ClosedResourceError: on ``OSError`` while ``self._closing`` is set
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None

    async def send(self, item: bytes) -> None:
        """
        Send one datagram via ``send()``.

        The call is retried after ``self._wait_until_writable()`` whenever the raw
        socket would block.

        :param item: the payload passed to the raw socket's ``send()``
        :raises ClosedResourceError: on ``OSError`` while ``self._closing`` is set
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if not self._closing:
                        raise BrokenResourceError from exc

                    raise ClosedResourceError from None
7✔
1723

1724

1725
# RunVar registered under the name "read_events", annotated as a mapping whose
# values are asyncio.Event objects.
# NOTE(review): presumably keyed by the object being waited on for readability;
# confirm against the code that reads and writes this variable.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
10✔
1726
# RunVar registered under the name "write_events", annotated as a mapping whose
# values are asyncio.Event objects.
# NOTE(review): presumably keyed by the object being waited on for writability;
# confirm against the code that reads and writes this variable.
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
10✔
1727

1728

1729
#
1730
# Synchronization
1731
#
1732

1733

1734
class Event(BaseEvent):
    """Event primitive that delegates to a wrapped ``asyncio.Event``."""

    def __new__(cls) -> Event:
        # Allocate the instance directly through object.__new__
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the wrapped event."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` if the wrapped event has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """
        Wait until the event has been set.

        If the event is already set, a backend checkpoint is performed instead of
        waiting on the wrapped event.

        """
        if not self.is_set():
            await self._event.wait()
        else:
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics built from the wrapped event's waiter count."""
        waiting_tasks = len(self._event._waiters)
        return EventStatistics(waiting_tasks)
9✔
1755

1756

1757
class Lock(BaseLock):
    """Lock that tracks its owning task and a FIFO queue of waiting tasks."""

    def __new__(cls, *, fast_acquire: bool = False) -> Lock:
        # Allocate the instance directly; the keyword is consumed by __init__
        return object.__new__(cls)

    def __init__(self, *, fast_acquire: bool = False) -> None:
        self._fast_acquire = fast_acquire
        self._owner_task: asyncio.Task | None = None
        self._waiters: deque[tuple[asyncio.Task, asyncio.Future]] = deque()

    async def acquire(self) -> None:
        """
        Acquire the lock, waiting in line if it is held or others are queued.

        :raises RuntimeError: if the current task already holds the lock

        """
        me = cast(asyncio.Task, current_task())
        uncontended = self._owner_task is None and not self._waiters
        if uncontended:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._owner_task = me

            # Unless on the "fast path", yield control of the event loop so that
            # other tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    # Give the lock back before propagating the cancellation
                    self.release()
                    raise

            return

        if self._owner_task == me:
            raise RuntimeError("Attempted to acquire an already held Lock")

        handoff: asyncio.Future[None] = asyncio.Future()
        entry = (me, handoff)
        self._waiters.append(entry)
        try:
            await handoff
        except CancelledError:
            # Leave the queue first so release() cannot hand the lock back to us
            self._waiters.remove(entry)
            if self._owner_task is me:
                # Ownership was already transferred to this task; pass it on
                self.release()

            raise
        else:
            self._waiters.remove(entry)

    def acquire_nowait(self) -> None:
        """
        Acquire the lock without waiting.

        :raises RuntimeError: if the current task already holds the lock
        :raises WouldBlock: if the lock is held or other tasks are queued for it

        """
        me = cast(asyncio.Task, current_task())
        if self._owner_task is None and not self._waiters:
            self._owner_task = me
            return

        if self._owner_task is me:
            raise RuntimeError("Attempted to acquire an already held Lock")

        raise WouldBlock

    def locked(self) -> bool:
        """Return ``True`` if some task currently owns the lock."""
        return self._owner_task is not None

    def release(self) -> None:
        """
        Release the lock, transferring it to the first non-cancelled waiter if any.

        :raises RuntimeError: if the current task is not the owner

        """
        if self._owner_task != current_task():
            raise RuntimeError("The current task is not holding this lock")

        successor = next(
            ((task, fut) for task, fut in self._waiters if not fut.cancelled()),
            None,
        )
        if successor is None:
            self._owner_task = None
            return

        # The waiter removes its own queue entry once it resumes in acquire()
        next_owner, handoff = successor
        self._owner_task = next_owner
        handoff.set_result(None)

    def statistics(self) -> LockStatistics:
        """Return the locked state, owner task info and number of queued waiters."""
        owner = self._owner_task
        owner_info = AsyncIOTaskInfo(owner) if owner else None
        return LockStatistics(self.locked(), owner_info, len(self._waiters))
9✔
1829

1830

1831
class Semaphore(BaseSemaphore):
    """Counting semaphore with a FIFO queue of waiting futures."""

    def __new__(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> Semaphore:
        # Allocate the instance directly; the arguments are consumed by __init__
        return object.__new__(cls)

    def __init__(
        self,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ):
        super().__init__(initial_value, max_value=max_value)
        self._value = initial_value
        self._max_value = max_value
        self._fast_acquire = fast_acquire
        self._waiters: deque[asyncio.Future[None]] = deque()

    async def acquire(self) -> None:
        """Decrement the semaphore, waiting in line if no value is available."""
        if self._value > 0 and not self._waiters:
            await AsyncIOBackend.checkpoint_if_cancelled()
            self._value -= 1

            # Unless on the "fast path", yield control of the event loop so that
            # other tasks can run too
            if not self._fast_acquire:
                try:
                    await AsyncIOBackend.cancel_shielded_checkpoint()
                except CancelledError:
                    # Give the value back before propagating the cancellation
                    self.release()
                    raise

            return

        waiter: asyncio.Future[None] = asyncio.Future()
        self._waiters.append(waiter)
        try:
            await waiter
        except CancelledError:
            if waiter in self._waiters:
                self._waiters.remove(waiter)
            else:
                # release() already dequeued this waiter, i.e. the value was
                # handed to us; pass it on instead of losing it
                self.release()

            raise

    def acquire_nowait(self) -> None:
        """
        Decrement the semaphore without waiting.

        :raises WouldBlock: if the current value is zero

        """
        if self._value == 0:
            raise WouldBlock

        self._value -= 1

    def release(self) -> None:
        """
        Hand the value to the first non-cancelled waiter, or increment the counter.

        :raises ValueError: if the value already equals the configured maximum

        """
        if self._max_value is not None and self._value == self._max_value:
            raise ValueError("semaphore released too many times")

        waiter = next((fut for fut in self._waiters if not fut.cancelled()), None)
        if waiter is None:
            self._value += 1
            return

        waiter.set_result(None)
        self._waiters.remove(waiter)

    @property
    def value(self) -> int:
        """The current value of the semaphore."""
        return self._value

    @property
    def max_value(self) -> int | None:
        """The maximum value given at construction, if any."""
        return self._max_value

    def statistics(self) -> SemaphoreStatistics:
        """Return statistics built from the number of queued waiters."""
        return SemaphoreStatistics(len(self._waiters))
9✔
1910

1911

1912
class CapacityLimiter(BaseCapacityLimiter):
    """Limiter that tracks token borrowers and a FIFO queue of waiting borrowers."""

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Allocate the instance directly; the argument is consumed by __init__
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        # Goes through the property setter, which validates the value
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the total number of tokens.

        If the total grows, up to as many queued waiters as the size of the increase
        are notified, in queue order.

        :raises TypeError: if the value is neither an ``int`` nor infinite
        :raises ValueError: if the value is less than 1

        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Notify waiting tasks that they have acquired the limiter
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently borrowed."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The total number of tokens minus the number of current borrowers."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """Acquire a token for the current task without waiting."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Acquire a token for the given borrower without waiting.

        :param borrower: the entity to record as holding the token
        :raises RuntimeError: if the borrower already holds a token
        :raises WouldBlock: if others are already queued or no token is available

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token for the given borrower, waiting in line if necessary.

        :param borrower: the entity to record as holding the token
        :raises RuntimeError: if the borrower already holds a token

        """
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Leave the queue if still in it; the wait did not complete
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Undo the acquisition for the entity it was made for. Using
                # release() here would target the current task, which is not
                # necessarily the borrower.
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by the given borrower.

        :param borrower: the entity whose token is given back
        :raises RuntimeError: if the borrower does not hold a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return borrowed/total token counts, the borrowers and the queue length."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
2024

2025

2026
# RunVar registered under the name "_default_thread_limiter", annotated as holding
# a CapacityLimiter.
# NOTE(review): presumably the limiter used by default for worker threads; confirm
# against the code that reads and sets this variable.
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
10✔
2027

2028

2029
#
2030
# Operating system signals
2031
#
2032

2033

2034
class _SignalReceiver:
    """
    Context manager and async iterator that queues signals delivered by the loop.

    Entering installs a handler on the event loop for each requested signal;
    exiting removes the handlers that were installed.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[Signals] = deque()
        # Completed by _deliver() to wake up a pending __anext__()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        """Queue a received signal and wake up the waiter, if not already done."""
        self._signal_queue.append(signum)
        wakeup = self._future
        if not wakeup.done():
            wakeup.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Duplicates in the requested signals are only registered once
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        """Return the oldest queued signal, waiting for one if the queue is empty."""
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # Arm a fresh future for _deliver() to complete
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
8✔
2074

2075

2076
#
2077
# Testing and debugging
2078
#
2079

2080

2081
class AsyncIOTaskInfo(TaskInfo):
    """Task information for an asyncio task, holding only a weak reference to it."""

    def __init__(self, task: asyncio.Task):
        state = _task_states.get(task)
        parent_id = None if state is None else state.parent_id
        super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
        self._task = weakref.ref(task)

    def has_pending_cancellation(self) -> bool:
        """
        Return ``True`` if the task has a cancellation pending.

        This is the case when the task itself reports one (``Task.cancelling()`` on
        Python 3.11+, a cancelled awaited future on older versions), or when the
        cancel scope recorded for the task is effectively cancelled.

        """
        task = self._task()
        if not task:
            # If the task isn't around anymore, it won't have a pending cancellation
            return False

        if sys.version_info >= (3, 11):
            if task.cancelling():
                return True
        elif (
            isinstance(task._fut_waiter, asyncio.Future)
            and task._fut_waiter.cancelled()
        ):
            return True

        state = _task_states.get(task)
        scope = state.cancel_scope if state else None
        if scope:
            return scope._effectively_cancelled

        return False
10✔
2111

2112

2113
class TestRunner(abc.TestRunner):
    """
    Test runner that executes tests and fixtures in one long-lived runner task.

    Coroutines are handed to the runner task through a memory object stream,
    together with a future that receives each coroutine's outcome.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        # An explicit loop factory takes precedence over use_uvloop
        if use_uvloop and loop_factory is None:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        self.get_loop().set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the event loop of the wrapped runner."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances; defer anything else to the default."""
        if not isinstance(context.get("exception"), Exception):
            loop.default_exception_handler(context)
            return

        self._exceptions.append(context["exception"])

    def _raise_async_exceptions(self) -> None:
        # Re-raise any exceptions raised in asynchronous callbacks
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """
        Await each received coroutine and report its outcome through its future.

        Cancellation is always propagated; other exceptions are propagated only if
        they are neither ``Exception`` nor ``OutcomeException`` instances.
        """
        from _pytest.outcomes import OutcomeException

        with receive_stream, self._send_stream:
            async for awaitable, outcome in receive_stream:
                try:
                    result = await awaitable
                except CancelledError as exc:
                    if not outcome.cancelled():
                        outcome.cancel(*exc.args)

                    raise
                except BaseException as exc:
                    if not outcome.cancelled():
                        outcome.set_exception(exc)

                    if not isinstance(exc, (Exception, OutcomeException)):
                        raise
                else:
                    if not outcome.cancelled():
                        outcome.set_result(result)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in the runner task and return its result."""
        if not self._runner_task:
            # Lazily create the stream pair and the task consuming it
            self._send_stream, receive_stream = create_memory_object_stream[
                tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        awaitable = func(*args, **kwargs)
        outcome: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((awaitable, outcome))
        return await outcome

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Drive an async generator fixture: yield its value, then finalize it.

        :raises RuntimeError: if the generator yields a second time

        """
        generator = fixture_func(**kwargs)
        value: T_Retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(generator.asend, None)
        )
        self._raise_async_exceptions()

        yield value

        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(generator.asend, None)
            )
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(generator.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        value = self.get_loop().run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return value

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a test in the runner task, then raise any collected exceptions."""
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            # Raised below, together with exceptions from asynchronous callbacks
            self._exceptions.append(exc)

        self._raise_async_exceptions()
10✔
2258

2259

2260
class AsyncIOBackend(AsyncBackend):
10✔
2261
    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func`` to completion inside a ``Runner`` and return its result.

        :param func: the coroutine function to call
        :param args: positional arguments passed to ``func``
        :param kwargs: keyword arguments passed to ``func``
        :param options: backend options; the keys ``debug``, ``loop_factory`` and
            ``use_uvloop`` are read here
        :return: the value returned by ``func``

        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            # The task state entry only exists while ``func`` is running
            _task_states[task] = TaskState(None, None)

            try:
                # ``kwargs`` used to be accepted but silently dropped here
                return await func(*args, **kwargs)
            finally:
                del _task_states[task]

        debug = options.get("debug", None)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            # Only imported when requested through the ``use_uvloop`` option
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())
10✔
2289

2290
    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, which this backend uses as its token."""
        running_loop = get_running_loop()
        return running_loop
10✔
2293

2294
    @classmethod
    def current_time(cls) -> float:
        """Return the current reading of the running event loop's clock."""
        running_loop = get_running_loop()
        return running_loop.time()
10✔
2297

2298
    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class this backend uses to signal cancellation."""
        exc_class: type[BaseException] = CancelledError
        return exc_class
10✔
2301

2302
    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once by awaiting ``sleep(0)``."""
        zero_delay = 0
        await sleep(zero_delay)
10✔
2305

2306
    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield to the event loop only if an enclosing cancel scope was cancelled.

        Walks from the current task's cancel scope towards its parents. The walk
        stops at the first shielded scope; a scope with ``cancel_called`` set
        triggers ``sleep(0)``. Returns immediately when there is no current task
        or the task has no recorded state.

        """
        this_task = current_task()
        if this_task is None:
            return

        try:
            scope = _task_states[this_task].cancel_scope
        except KeyError:
            return

        while scope:
            if scope.cancel_called:
                # Stay on this scope and yield again on the next iteration
                await sleep(0)
                continue

            if scope.shield:
                return

            scope = scope._parent_scope
10✔
2324

2325
    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield control to the event loop once from inside a shielded scope."""
        shielded_scope = CancelScope(shield=True)
        with shielded_scope:
            await sleep(0)
10✔
2329

2330
    @classmethod
    async def sleep(cls, delay: float) -> None:
        """
        Suspend the current task by awaiting :func:`asyncio.sleep`.

        :param delay: the delay passed on to :func:`asyncio.sleep`

        """
        # ``sleep`` here resolves to the module-level asyncio function, not this
        # method
        pause = sleep(delay)
        await pause
10✔
2333

2334
    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """
        Return a new :class:`CancelScope`.

        :param deadline: forwarded to the scope's ``deadline`` argument
        :param shield: forwarded to the scope's ``shield`` argument

        """
        scope_options = {"deadline": deadline, "shield": shield}
        return CancelScope(**scope_options)
10✔
2339

2340
    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline among the cancel scopes affecting this task.

        The walk starts at the current task's cancel scope and moves towards its
        parents, stopping after the first shielded scope.

        :return: ``math.inf`` if there is no current task, no recorded task state
            or no deadline; ``-math.inf`` if a visited scope has already been
            cancelled; otherwise the smallest deadline seen

        """
        task = current_task()
        if task is None:
            # Same guard as in checkpoint_if_cancelled(): with no running task
            # there is no scope to inspect, and a ``None`` key would not be covered
            # by the ``KeyError`` handler below if the mapping is weak-keyed
            return math.inf

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                # A cancelled scope means the deadline has effectively passed
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                # Scopes outside a shielded scope do not apply
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline
9✔
2361

2362
    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Return a new :class:`TaskGroup` instance."""
        task_group = TaskGroup()
        return task_group
10✔
2365

2366
    @classmethod
    def create_event(cls) -> abc.Event:
        """Return a new :class:`Event` instance."""
        event = Event()
        return event
10✔
2369

2370
    @classmethod
    def create_lock(cls, *, fast_acquire: bool) -> abc.Lock:
        """
        Return a new :class:`Lock` instance.

        :param fast_acquire: forwarded to the lock's ``fast_acquire`` argument

        """
        lock = Lock(fast_acquire=fast_acquire)
        return lock
9✔
2373

2374
    @classmethod
    def create_semaphore(
        cls,
        initial_value: int,
        *,
        max_value: int | None = None,
        fast_acquire: bool = False,
    ) -> abc.Semaphore:
        """
        Return a new :class:`Semaphore` instance.

        :param initial_value: forwarded as the first positional argument
        :param max_value: forwarded to the semaphore's ``max_value`` argument
        :param fast_acquire: forwarded to the semaphore's ``fast_acquire`` argument

        """
        semaphore = Semaphore(
            initial_value,
            max_value=max_value,
            fast_acquire=fast_acquire,
        )
        return semaphore
9✔
2383

2384
    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """
        Return a new :class:`CapacityLimiter` instance.

        :param total_tokens: forwarded as the limiter's only argument

        """
        capacity_limiter = CapacityLimiter(total_tokens)
        return capacity_limiter
9✔
2387

2388
    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Hand ``func(*args)`` to a worker thread and wait for its result.

        An idle worker is reused when one is available; otherwise a new
        ``WorkerThread`` is started and set to stop when the root task is done.

        :param func: the callable queued for the worker
        :param args: positional arguments queued along with ``func``
        :param abandon_on_cancel: if false, the wait happens inside a shielded
            cancel scope
        :param limiter: capacity limiter to hold while the call is in progress;
            the default thread limiter is used if omitted (or falsy)
        :return: the result of awaiting the future handed to the worker

        """
        await cls.checkpoint()

        # Fetch the worker bookkeeping containers, creating and storing them the
        # first time either lookup fails
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if idle_workers:
                    # Reuse the most recently idled worker
                    worker = idle_workers.pop()

                    # Stop the remaining workers, oldest first, that have been
                    # idle for MAX_IDLE_TIME seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        oldest = idle_workers[0]
                        if now - oldest.idle_since < WorkerThread.MAX_IDLE_TIME:
                            break

                        idle_workers.popleft()
                        oldest.root_task.remove_done_callback(oldest.stop)
                        oldest.stop()
                else:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)

                # The worker gets a copy of the current context in which the
                # sniffio library variable is reset to None
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)

                use_own_scope = abandon_on_cancel or scope._parent_scope is None
                worker_scope = scope if use_own_scope else scope._parent_scope

                job = (context, func, args, future, worker_scope)
                worker.queue.put_nowait(job)
                return await future
10✔
2446

2447
    @classmethod
    def check_cancelled(cls) -> None:
        """
        Raise :exc:`~asyncio.CancelledError` if an applicable scope was cancelled.

        Starts from the cancel scope stored in ``threadlocals`` and walks up the
        parent chain, stopping after the first shielded scope.

        :raises asyncio.CancelledError: if a visited scope has ``cancel_called``
            set

        """
        scope: CancelScope | None = threadlocals.current_cancel_scope
        while scope is not None:
            if scope.cancel_called:
                raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")
            elif scope.shield:
                # Outer scopes do not apply past a shielded scope
                break

            scope = scope._parent_scope
10✔
2458

2459
    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in the event loop given by ``token`` and block until
        it finishes.

        :param func: the coroutine function to call in the event loop
        :param args: positional arguments passed to ``func``
        :param token: the event loop to submit the coroutine to
        :return: the value returned by ``func``
        :raises concurrent.futures.CancelledError: if the call is cancelled with
            :exc:`asyncio.CancelledError`

        """

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            __tracebackhide__ = True
            this_task = cast(asyncio.Task, current_task())
            # Attach the new task to the cancel scope passed in by the caller
            _task_states[this_task] = TaskState(None, scope)
            scope._tasks.add(this_task)
            try:
                return await func(*args)
            except CancelledError as exc:
                # Re-raise as the concurrent.futures flavor of CancelledError
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(this_task)

        target_loop = cast(AbstractEventLoop, token)

        # Submit from a copy of the current context with the sniffio library
        # variable set to "asyncio"
        ctx = copy_context()
        ctx.run(sniffio.current_async_library_cvar.set, "asyncio")
        coro = task_wrapper(threadlocals.current_cancel_scope)
        result_future: concurrent.futures.Future[T_Retval] = ctx.run(
            asyncio.run_coroutine_threadsafe, coro, target_loop
        )
        return result_future.result()
10✔
2486

2487
    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in the event loop given by ``token`` and block until
        it finishes.

        :param func: the callable to run in the event loop thread
        :param args: positional arguments passed to ``func``
        :param token: the event loop to schedule the call on
        :return: the value returned by ``func``

        """
        outcome: concurrent.futures.Future[T_Retval] = Future()

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                result = func(*args)
                outcome.set_result(result)
            except BaseException as exc:
                outcome.set_exception(exc)
                # Everything is relayed through the future, but exceptions that
                # are not Exception subclasses are also re-raised here
                if isinstance(exc, Exception):
                    return

                raise

        target_loop = cast(AbstractEventLoop, token)
        target_loop.call_soon_threadsafe(wrapper)
        return outcome.result()
10✔
2508

2509
    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Return a new :class:`BlockingPortal` instance."""
        portal = BlockingPortal()
        return portal
10✔
2512

2513
    @classmethod
    async def open_process(
        cls,
        command: StrOrBytesPath | Sequence[StrOrBytesPath],
        *,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        **kwargs: Any,
    ) -> Process:
        """
        Start a subprocess and wrap it, along with its standard streams.

        A ``str``, ``bytes`` or path-like ``command`` is run through the shell;
        any other value is unpacked as the argument list of the program to run.

        :param command: the shell command, or the sequence of program arguments
        :param stdin: forwarded to the subprocess creation call
        :param stdout: forwarded to the subprocess creation call
        :param stderr: forwarded to the subprocess creation call
        :param kwargs: extra keyword arguments for the subprocess creation call
        :return: a :class:`Process` wrapping the new subprocess

        """
        await cls.checkpoint()

        # Path-like objects are converted to their str/bytes form first
        if isinstance(command, PathLike):
            command = os.fspath(command)

        if not isinstance(command, (str, bytes)):
            proc = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )
        else:
            proc = await asyncio.create_subprocess_shell(
                command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                **kwargs,
            )

        # Wrap only the streams that exist on the process object
        stdin_stream = None
        if proc.stdin:
            stdin_stream = StreamWriterWrapper(proc.stdin)

        stdout_stream = None
        if proc.stdout:
            stdout_stream = StreamReaderWrapper(proc.stdout)

        stderr_stream = None
        if proc.stderr:
            stderr_stream = StreamReaderWrapper(proc.stderr)

        return Process(proc, stdin_stream, stdout_stream, stderr_stream)
9✔
2548

2549
    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the given worker processes to be shut down on exit.

        Starts a task running ``_shutdown_process_pool_on_exit(workers)`` and adds
        a done callback to the root task that calls
        ``_forcibly_shutdown_process_pool_on_exit`` with ``workers``.

        :param workers: the set of worker processes to shut down

        """
        shutdown_coro = _shutdown_process_pool_on_exit(workers)
        create_task(shutdown_coro, name="AnyIO process pool shutdown task")

        root_task = find_root_task()
        root_task.add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)  # type:ignore[arg-type]
        )
2558

2559
    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """
        Open a TCP connection and wrap it in a :class:`SocketStream`.

        :param host: host to connect to
        :param port: port to connect to
        :param local_address: passed to the event loop as ``local_addr``
        :return: a stream wrapping the new transport and protocol

        """
        running_loop = get_running_loop()
        connection = await running_loop.create_connection(
            StreamProtocol, host, port, local_addr=local_address
        )
        transport, protocol = cast(
            tuple[asyncio.Transport, StreamProtocol], connection
        )
        # The transport starts out with reading paused
        transport.pause_reading()
        return SocketStream(transport, protocol)
10✔
2571

2572
    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """
        Connect a new non-blocking UNIX socket to ``path``.

        The socket created here is closed if connecting fails for any reason,
        including cancellation while waiting for the socket to become writable.

        :param path: filesystem path of the socket to connect to
        :return: a stream wrapping the connected socket

        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        # A single outer handler owns the cleanup: an exception raised inside the
        # ``except BlockingIOError`` block (e.g. cancellation at ``await f``) would
        # not be caught by a sibling ``except`` clause of the same ``try``
        try:
            raw_socket.setblocking(False)
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    # Wait until the socket is writable, then retry connect()
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
7✔
2591

2592
    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """
        Wrap ``sock`` in a :class:`TCPSocketListener`.

        :param sock: the socket to wrap

        """
        listener = TCPSocketListener(sock)
        return listener
10✔
2595

2596
    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """
        Wrap ``sock`` in a :class:`UNIXSocketListener`.

        :param sock: the socket to wrap

        """
        listener = UNIXSocketListener(sock)
        return listener
7✔
2599

2600
    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a datagram endpoint and wrap it in a UDP socket object.

        :param family: address family passed to the event loop
        :param local_address: passed to the event loop as ``local_addr``
        :param remote_address: passed to the event loop as ``remote_addr``; when
            truthy, a :class:`ConnectedUDPSocket` is returned
        :param reuse_port: passed to the event loop as ``reuse_port``
        :return: a :class:`UDPSocket` or :class:`ConnectedUDPSocket`

        """
        running_loop = get_running_loop()
        transport, protocol = await running_loop.create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )

        # If the protocol already recorded an error, close the transport and
        # raise that error
        if protocol.exception:
            transport.close()
            raise protocol.exception

        socket_class = ConnectedUDPSocket if remote_address else UDPSocket
        return socket_class(transport, protocol)
9✔
2623

2624
    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap ``raw_socket`` as a UNIX datagram socket, connecting it if requested.

        When ``remote_path`` is given, ``raw_socket`` is closed if connecting
        fails for any reason, including cancellation while waiting for the socket
        to become writable.

        :param raw_socket: the socket to wrap
        :param remote_path: path to connect the socket to; if falsy, the socket is
            wrapped without connecting
        :return: a connected or unconnected UNIX datagram socket wrapper

        """
        await cls.checkpoint()
        loop = get_running_loop()

        if not remote_path:
            return UNIXDatagramSocket(raw_socket)

        # A single outer handler owns the cleanup: an exception raised inside the
        # ``except BlockingIOError`` block (e.g. cancellation at ``await f``) would
        # not be caught by a sibling ``except`` clause of the same ``try``
        try:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    # Wait until the socket is writable, then retry connect()
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
7✔
2647

2648
    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """
        Resolve ``host`` and ``port`` through the running event loop.

        All arguments are forwarded unchanged to the loop's ``getaddrinfo()``.

        :return: the list returned by the event loop

        """
        running_loop = get_running_loop()
        results = await running_loop.getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )
        return results
2670

2671
    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """
        Look up ``sockaddr`` through the running event loop's ``getnameinfo()``.

        :param sockaddr: the socket address to look up
        :param flags: flags forwarded to the event loop
        :return: the tuple returned by the event loop

        """
        running_loop = get_running_loop()
        name_info = await running_loop.getnameinfo(sockaddr, flags)
        return name_info
9✔
2676

2677
    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until the event loop reports ``sock`` as readable.

        :param sock: the socket to watch
        :raises BusyResourceError: if the read event registry already holds an
            entry for ``sock`` (presumably another task waiting on it)
        :raises ClosedResourceError: if the registry entry for ``sock`` was no
            longer present when the wait ended

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = asyncio.Event()
        loop.add_reader(sock, event.set)
        # Register the event only after the reader was installed; otherwise a
        # failing add_reader() would leave a stale entry behind and every later
        # call for this socket would raise BusyResourceError
        read_events[sock] = event
        try:
            await event.wait()
        finally:
            # A missing entry means someone else removed it while we waited
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError
×
2703

2704
    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until the event loop reports ``sock`` as writable.

        :param sock: the socket to watch
        :raises BusyResourceError: if the write event registry already holds an
            entry for ``sock`` (presumably another task waiting on it)
        :raises ClosedResourceError: if the registry entry for ``sock`` was no
            longer present when the wait ended

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = asyncio.Event()
        loop.add_writer(sock.fileno(), event.set)
        # Register the event only after the writer was installed; otherwise a
        # failing add_writer() would leave a stale entry behind and every later
        # call for this socket would raise BusyResourceError
        write_events[sock] = event
        try:
            await event.wait()
        finally:
            # A missing entry means someone else removed it while we waited
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError
×
2730

2731
    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """
        Return the default thread limiter, creating it on first use.

        When no limiter has been stored yet, a :class:`CapacityLimiter` with 40
        tokens is created, stored and returned.

        """
        try:
            existing = _default_thread_limiter.get()
        except LookupError:
            pass
        else:
            return existing

        new_limiter = CapacityLimiter(40)
        _default_thread_limiter.set(new_limiter)
        return new_limiter
10✔
2739

2740
    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> AbstractContextManager[AsyncIterator[Signals]]:
        """
        Return a ``_SignalReceiver`` for the given signals.

        :param signals: the signals to pass to the receiver

        """
        receiver = _SignalReceiver(signals)
        return receiver
8✔
2745

2746
    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return an :class:`AsyncIOTaskInfo` wrapping the current asyncio task."""
        this_task = current_task()
        return AsyncIOTaskInfo(this_task)  # type: ignore[arg-type]
10✔
2749

2750
    @classmethod
    def get_running_tasks(cls) -> Sequence[TaskInfo]:
        """
        Return task info objects for all tasks that are not done yet.

        :return: a list with one :class:`AsyncIOTaskInfo` per unfinished task
            reported by :func:`asyncio.all_tasks`

        """
        task_infos: list[TaskInfo] = []
        for task in all_tasks():
            if task.done():
                continue

            task_infos.append(AsyncIOTaskInfo(task))

        return task_infos
10✔
2753

2754
    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every task other than the current one is blocked.

        A task counts as blocked when it is waiting on a future that is not done
        yet. All tasks are re-checked after a 0.1 second sleep whenever a task
        that is not blocked is found.

        """

        def is_blocked(task: asyncio.Task) -> bool:
            waiter = task._fut_waiter  # type: ignore[attr-defined]
            return waiter is not None and not waiter.done()

        await cls.checkpoint()
        me = current_task()
        while True:
            others = (task for task in all_tasks() if task is not me)
            if all(is_blocked(task) for task in others):
                return

            await sleep(0.1)
10✔
2769

2770
    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """
        Return a new :class:`TestRunner`.

        :param options: keyword arguments passed to the :class:`TestRunner`
            constructor

        """
        test_runner = TestRunner(**options)
        return test_runner
10✔
2773

2774

2775
# Module-level alias for this module's backend implementation class
# NOTE(review): presumably looked up by name when the backend is loaded — confirm
backend_class = AsyncIOBackend
10✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc