• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 4985485255

pending completion
4985485255

push

github

GitHub
[pre-commit.ci] pre-commit autoupdate (#569)

4028 of 4451 relevant lines covered (90.5%)

9.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.01
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
11✔
2

3
import array
11✔
4
import asyncio
11✔
5
import concurrent.futures
11✔
6
import math
11✔
7
import socket
11✔
8
import sys
11✔
9
import threading
11✔
10
from asyncio import (
11✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio import run as native_run
11✔
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
11✔
21
from collections import OrderedDict, deque
11✔
22
from collections.abc import AsyncIterator, Iterable
11✔
23
from concurrent.futures import Future
11✔
24
from contextvars import Context, copy_context
11✔
25
from dataclasses import dataclass
11✔
26
from functools import partial, wraps
11✔
27
from inspect import CORO_RUNNING, CORO_SUSPENDED, getcoroutinestate
11✔
28
from io import IOBase
11✔
29
from os import PathLike
11✔
30
from queue import Queue
11✔
31
from signal import Signals
11✔
32
from socket import AddressFamily, SocketKind
11✔
33
from threading import Thread
11✔
34
from types import TracebackType
11✔
35
from typing import (
11✔
36
    IO,
37
    Any,
38
    AsyncGenerator,
39
    Awaitable,
40
    Callable,
41
    Collection,
42
    ContextManager,
43
    Coroutine,
44
    Deque,
45
    Generator,
46
    Iterator,
47
    Mapping,
48
    Optional,
49
    Sequence,
50
    Tuple,
51
    TypeVar,
52
    cast,
53
)
54
from weakref import WeakKeyDictionary
11✔
55

56
import sniffio
11✔
57

58
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
11✔
59
from .._core._eventloop import claim_worker_thread
11✔
60
from .._core._exceptions import (
11✔
61
    BrokenResourceError,
62
    BusyResourceError,
63
    ClosedResourceError,
64
    EndOfStream,
65
    WouldBlock,
66
)
67
from .._core._sockets import convert_ipv6_sockaddr
11✔
68
from .._core._streams import create_memory_object_stream
11✔
69
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
11✔
70
from .._core._synchronization import Event as BaseEvent
11✔
71
from .._core._synchronization import ResourceGuard
11✔
72
from .._core._tasks import CancelScope as BaseCancelScope
11✔
73
from ..abc import (
11✔
74
    AsyncBackend,
75
    IPSockAddrType,
76
    SocketListener,
77
    UDPPacketType,
78
    UNIXDatagramPacketType,
79
)
80
from ..lowlevel import RunVar
11✔
81
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11✔
82

83
if sys.version_info < (3, 11):
11✔
84
    from exceptiongroup import BaseExceptionGroup, ExceptionGroup
7✔
85

86
def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
    """
    Return the coroutine object wrapped by the given task.

    On Python 3.8 and later this goes through the public ``Task.get_coro()`` method;
    on earlier versions the private ``_coro`` attribute is read instead.

    :param task: the task to inspect
    :return: whatever object the task reports as its coroutine

    """
    if sys.version_info < (3, 8):
        return task._coro

    return task.get_coro()
3✔
95

96

97
T_Retval = TypeVar("T_Retval")
11✔
98
T_contra = TypeVar("T_contra", contravariant=True)
11✔
99

100
# Check whether there is native support for task names in asyncio (3.8+)
101
_native_task_names = hasattr(asyncio.Task, "get_name")
11✔
102

103
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
11✔
104

105

106
def find_root_task() -> asyncio.Task:
    """
    Find the task considered to be the root of the current event loop run.

    The lookup proceeds in this order:

    #. the task cached in the ``_root_task`` run variable, if it is not done yet
    #. a not-yet-done task that has ``_run_until_complete_cb`` (or a callback whose
       ``__module__`` is ``uvloop.loop``) among its done callbacks; this task is also
       stored in ``_root_task`` before being returned
    #. the host task of the outermost cancel scope recorded for the current task
    #. the current task itself

    :return: the task found by the first step that yields a result

    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback, _context in candidate._callbacks:
            started_by_loop = (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            )
            if started_by_loop:
                _root_task.set(candidate)
                return candidate

    # Look up the topmost task in the AnyIO task tree, if possible
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if not state:
        return current

    # Climb to the outermost cancel scope known for this task
    scope = state.cancel_scope
    while scope and scope._parent_scope is not None:
        scope = scope._parent_scope

    if scope is None:
        return current

    return cast(asyncio.Task, scope._host_task)
×
135

136

137
def get_callable_name(func: Callable) -> str:
    """
    Build a dotted name for the given callable.

    The name is made from the callable's ``__module__`` and ``__qualname__``
    attributes, joined with a period. An attribute that is missing or falsy is left
    out, so the result may be just one of the two parts or an empty string.

    :param func: the object to name
    :return: the dotted name

    """
    name_parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(filter(None, name_parts))
11✔
141

142

143
#
144
# Event loop
145
#
146

147
_run_vars = (
11✔
148
    WeakKeyDictionary()
149
)  # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any]
150

151

152
def _task_started(task: asyncio.Task) -> bool:
    """
    Return ``True`` if the task has been started and has not finished.

    The decision is based on the state of the task's coroutine as reported by
    :func:`inspect.getcoroutinestate`: running or suspended counts as started.

    :param task: the task to check
    :raises Exception: if inspecting the coroutine state fails with an
        :exc:`AttributeError`

    """
    coro = cast(Coroutine[Any, Any, Any], get_coro(task))
    try:
        state = getcoroutinestate(coro)
    except AttributeError:
        # task coro is async_generator_asend https://bugs.python.org/issue37771
        raise Exception(f"Cannot determine if task {task} has started or not")

    return state in (CORO_RUNNING, CORO_SUSPENDED)
×
160

161

162
def _maybe_set_event_loop_policy(
    policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool
) -> None:
    """
    Install an asyncio event loop policy, if there is one to install.

    :param policy: the policy to install, or ``None`` to let this function pick one
    :param use_uvloop: if ``True`` and no policy was given, try to use uvloop's event
        loop policy (only attempted on CPython)

    """
    # On CPython, use uvloop when possible if no other policy has been given and if not
    # explicitly disabled
    uvloop_requested = (
        policy is None and use_uvloop and sys.implementation.name == "cpython"
    )
    if uvloop_requested:
        try:
            import uvloop
        except ImportError:
            # uvloop is not installed; carry on without a policy
            pass
        else:
            # uvloop 0.14.0 and earlier lack shutdown_default_executor(), so uvloop
            # is only usable if asyncio has no such method either, or uvloop has it
            asyncio_has_shutdown = hasattr(
                asyncio.AbstractEventLoop, "shutdown_default_executor"
            )
            if not asyncio_has_shutdown or hasattr(
                uvloop.loop.Loop, "shutdown_default_executor"
            ):
                policy = uvloop.EventLoopPolicy()

    if policy is None:
        return

    asyncio.set_event_loop_policy(policy)
11✔
181

182

183
#
184
# Timeouts and cancellation
185
#
186

187

188
class CancelScope(BaseCancelScope):
11✔
189
    def __new__(
11✔
190
        cls, *, deadline: float = math.inf, shield: bool = False
191
    ) -> CancelScope:
192
        return object.__new__(cls)
11✔
193

194
    def __init__(self, deadline: float = math.inf, shield: bool = False):
11✔
195
        self._deadline = deadline
11✔
196
        self._shield = shield
11✔
197
        self._parent_scope: CancelScope | None = None
11✔
198
        self._cancel_called = False
11✔
199
        self._active = False
11✔
200
        self._timeout_handle: asyncio.TimerHandle | None = None
11✔
201
        self._cancel_handle: asyncio.Handle | None = None
11✔
202
        self._tasks: set[asyncio.Task] = set()
11✔
203
        self._host_task: asyncio.Task | None = None
11✔
204
        self._timeout_expired = False
11✔
205
        self._cancel_calls: int = 0
11✔
206

207
    def __enter__(self) -> CancelScope:
11✔
208
        if self._active:
11✔
209
            raise RuntimeError(
×
210
                "Each CancelScope may only be used for a single 'with' block"
211
            )
212

213
        self._host_task = host_task = cast(asyncio.Task, current_task())
11✔
214
        self._tasks.add(host_task)
11✔
215
        try:
11✔
216
            task_state = _task_states[host_task]
11✔
217
        except KeyError:
11✔
218
            task_name = host_task.get_name() if _native_task_names else None
11✔
219
            task_state = TaskState(None, task_name, self)
11✔
220
            _task_states[host_task] = task_state
11✔
221
        else:
222
            self._parent_scope = task_state.cancel_scope
11✔
223
            task_state.cancel_scope = self
11✔
224

225
        self._timeout()
11✔
226
        self._active = True
11✔
227

228
        # Start cancelling the host task if the scope was cancelled before entering
229
        if self._cancel_called:
11✔
230
            self._deliver_cancellation()
11✔
231

232
        return self
11✔
233

234
    def __exit__(
11✔
235
        self,
236
        exc_type: type[BaseException] | None,
237
        exc_val: BaseException | None,
238
        exc_tb: TracebackType | None,
239
    ) -> bool | None:
240
        if not self._active:
11✔
241
            raise RuntimeError("This cancel scope is not active")
10✔
242
        if current_task() is not self._host_task:
11✔
243
            raise RuntimeError(
10✔
244
                "Attempted to exit cancel scope in a different task than it was "
245
                "entered in"
246
            )
247

248
        assert self._host_task is not None
11✔
249
        host_task_state = _task_states.get(self._host_task)
11✔
250
        if host_task_state is None or host_task_state.cancel_scope is not self:
11✔
251
            raise RuntimeError(
10✔
252
                "Attempted to exit a cancel scope that isn't the current tasks's "
253
                "current cancel scope"
254
            )
255

256
        self._active = False
11✔
257
        if self._timeout_handle:
11✔
258
            self._timeout_handle.cancel()
11✔
259
            self._timeout_handle = None
11✔
260

261
        self._tasks.remove(self._host_task)
11✔
262

263
        host_task_state.cancel_scope = self._parent_scope
11✔
264

265
        # Restart the cancellation effort in the farthest directly cancelled parent
266
        # scope if this one was shielded
267
        if self._shield:
11✔
268
            self._deliver_cancellation_to_parent()
11✔
269

270
        if exc_val is not None:
11✔
271
            exceptions = (
11✔
272
                exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
273
            )
274
            if all(isinstance(exc, CancelledError) for exc in exceptions):
11✔
275
                if self._timeout_expired:
11✔
276
                    return self._uncancel()
11✔
277
                elif not self._cancel_called:
11✔
278
                    # Task was cancelled natively
279
                    return None
11✔
280
                elif not self._parent_cancelled():
11✔
281
                    # This scope was directly cancelled
282
                    return self._uncancel()
11✔
283

284
        return None
11✔
285

286
    def _uncancel(self) -> bool:
        """
        Reset the count of cancel calls made by this scope on its host task.

        On Python 3.11 and later (and only if there is a host task),
        ``uncancel()`` is called on the host task once for every counted cancel
        call before the counter is reset.

        :return: ``True`` on Python versions before 3.11 or when there is no host
            task; otherwise ``True`` only if the host task's ``cancelling()`` count
            is zero after the uncancel calls

        """
        host_task = self._host_task
        if sys.version_info < (3, 11) or host_task is None:
            self._cancel_calls = 0
            return True

        # Uncancel all AnyIO cancellations
        for _ in range(self._cancel_calls):
            host_task.uncancel()

        self._cancel_calls = 0
        return not host_task.cancelling()
4✔
297

298
    def _timeout(self) -> None:
11✔
299
        if self._deadline != math.inf:
11✔
300
            loop = get_running_loop()
11✔
301
            if loop.time() >= self._deadline:
11✔
302
                self._timeout_expired = True
11✔
303
                self.cancel()
11✔
304
            else:
305
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)
11✔
306

307
    def _deliver_cancellation(self) -> None:
11✔
308
        """
309
        Deliver cancellation to directly contained tasks and nested cancel scopes.
310

311
        Schedule another run at the end if we still have tasks eligible for
312
        cancellation.
313
        """
314
        should_retry = False
11✔
315
        current = current_task()
11✔
316
        for task in self._tasks:
11✔
317
            if task._must_cancel:  # type: ignore[attr-defined]
11✔
318
                continue
10✔
319

320
            # The task is eligible for cancellation if it has started and is not in a
321
            # cancel scope shielded from this one
322
            cancel_scope = _task_states[task].cancel_scope
11✔
323
            while cancel_scope is not self:
11✔
324
                if cancel_scope is None or cancel_scope._shield:
11✔
325
                    break
11✔
326
                else:
327
                    cancel_scope = cancel_scope._parent_scope
11✔
328
            else:
329
                should_retry = True
11✔
330
                if task is not current and (
11✔
331
                    task is self._host_task or _task_started(task)
332
                ):
333
                    self._cancel_calls += 1
11✔
334
                    task.cancel()
11✔
335

336
        # Schedule another callback if there are still tasks left
337
        if should_retry:
11✔
338
            self._cancel_handle = get_running_loop().call_soon(
11✔
339
                self._deliver_cancellation
340
            )
341
        else:
342
            self._cancel_handle = None
11✔
343

344
    def _deliver_cancellation_to_parent(self) -> None:
        """
        Start cancellation effort in the farthest directly cancelled parent scope.

        The chain of parent scopes is walked outwards. Every scope that has been
        cancelled but has no cancel handle set replaces the previous candidate, so
        the outermost such scope wins. The walk ends after the first shielded scope
        (which is itself still considered) or when the chain runs out.
        """
        target: CancelScope | None = None
        ancestor = self._parent_scope
        while ancestor is not None:
            needs_delivery = (
                ancestor._cancel_called and ancestor._cancel_handle is None
            )
            if needs_delivery:
                target = ancestor

            # No point in looking beyond any shielded scope
            if ancestor._shield:
                break

            ancestor = ancestor._parent_scope

        if target is None:
            return

        target._deliver_cancellation()
10✔
360

361
    def _parent_cancelled(self) -> bool:
        """
        Check whether any enclosing cancel scope has been cancelled.

        Parent scopes are examined from the innermost outwards. A shielded scope
        stops the search before its own cancellation state is looked at.

        :return: ``True`` if a cancelled parent scope was found before reaching a
            shielded scope or the end of the chain, ``False`` otherwise

        """
        ancestor = self._parent_scope
        while ancestor is not None:
            if ancestor._shield:
                return False

            if ancestor._cancel_called:
                return True

            ancestor = ancestor._parent_scope

        return False
11✔
371

372
    def cancel(self) -> None:
        """
        Cancel this scope.

        Only the first call has any effect. It cancels and clears the pending
        timeout handle (if any), marks the scope as cancelled and, if the scope
        already has a host task, starts delivering the cancellation.
        """
        if self._cancel_called:
            return

        timeout_handle = self._timeout_handle
        if timeout_handle:
            timeout_handle.cancel()
            self._timeout_handle = None

        self._cancel_called = True
        if self._host_task is not None:
            self._deliver_cancellation()
11✔
381

382
    @property
11✔
383
    def deadline(self) -> float:
8✔
384
        return self._deadline
10✔
385

386
    @deadline.setter
11✔
387
    def deadline(self, value: float) -> None:
8✔
388
        self._deadline = float(value)
10✔
389
        if self._timeout_handle is not None:
10✔
390
            self._timeout_handle.cancel()
10✔
391
            self._timeout_handle = None
10✔
392

393
        if self._active and not self._cancel_called:
10✔
394
            self._timeout()
10✔
395

396
    @property
11✔
397
    def cancel_called(self) -> bool:
8✔
398
        return self._cancel_called
11✔
399

400
    @property
11✔
401
    def shield(self) -> bool:
8✔
402
        return self._shield
11✔
403

404
    @shield.setter
11✔
405
    def shield(self, value: bool) -> None:
8✔
406
        if self._shield != value:
10✔
407
            self._shield = value
10✔
408
            if not value:
10✔
409
                self._deliver_cancellation_to_parent()
10✔
410

411

412
#
413
# Task states
414
#
415

416

417
class TaskState:
    """
    Holds per-task bookkeeping data outside of the Task object.

    The data cannot be stored on the Task instance itself because there are no
    guarantees about its implementation.

    :param parent_id: identifier recorded for the task's parent, or ``None``
    :param name: the task's name, or ``None``
    :param cancel_scope: the cancel scope currently associated with the task, or
        ``None``

    """

    __slots__ = ("parent_id", "name", "cancel_scope")

    def __init__(
        self, parent_id: int | None, name: str | None, cancel_scope: CancelScope | None
    ):
        self.cancel_scope = cancel_scope
        self.name = name
        self.parent_id = parent_id
11✔
431

432

433
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
11✔
434

435

436
#
437
# Task groups
438
#
439

440

441
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object backed by an asyncio future.

    :param future: the future that receives the value passed to :meth:`started`
    :param parent_id: the parent identifier to record for the task that calls
        :meth:`started`

    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._future = future
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        """
        Resolve the future with the given value and record the parent identifier.

        :param value: the result to set on the future
        :raises RuntimeError: if the future was already in a state where a result
            could not be set

        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None
        else:
            # Only reached when the result was set: update the calling task's
            # recorded parent
            calling_task = cast(asyncio.Task, current_task())
            _task_states[calling_task].parent_id = self._parent_id
11✔
456

457

458
def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
11✔
459
    exceptions = list(excgroup.exceptions)
11✔
460
    modified = False
11✔
461
    for i, exc in enumerate(exceptions):
11✔
462
        if isinstance(exc, BaseExceptionGroup):
11✔
463
            new_exc = collapse_exception_group(exc)
10✔
464
            if new_exc is not exc:
10✔
465
                modified = True
10✔
466
                exceptions[i] = new_exc
10✔
467

468
    if len(exceptions) == 1:
11✔
469
        return exceptions[0]
11✔
470
    elif modified:
10✔
471
        return excgroup.derive(exceptions)
10✔
472
    else:
473
        return excgroup
10✔
474

475

476
def walk_exception_group(excgroup: BaseExceptionGroup) -> Iterator[BaseException]:
11✔
477
    for exc in excgroup.exceptions:
10✔
478
        if isinstance(exc, BaseExceptionGroup):
10✔
479
            yield from walk_exception_group(exc)
×
480
        else:
481
            yield exc
10✔
482

483

484
def is_anyio_cancelled_exc(exc: BaseException) -> bool:
    """
    Tell whether the exception is a :exc:`~asyncio.CancelledError` without arguments.

    :param exc: the exception to check
    :return: ``True`` if ``exc`` is a ``CancelledError`` with empty ``args``,
        ``False`` otherwise

    """
    if not isinstance(exc, CancelledError):
        return False

    return not exc.args
11✔
486

487

488
class TaskGroup(abc.TaskGroup):
11✔
489
    def __init__(self) -> None:
11✔
490
        self.cancel_scope: CancelScope = CancelScope()
11✔
491
        self._active = False
11✔
492
        self._exceptions: list[BaseException] = []
11✔
493

494
    async def __aenter__(self) -> TaskGroup:
11✔
495
        self.cancel_scope.__enter__()
11✔
496
        self._active = True
11✔
497
        return self
11✔
498

499
    async def __aexit__(
11✔
500
        self,
501
        exc_type: type[BaseException] | None,
502
        exc_val: BaseException | None,
503
        exc_tb: TracebackType | None,
504
    ) -> bool | None:
505
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
11✔
506
        if exc_val is not None:
11✔
507
            self.cancel_scope.cancel()
11✔
508
            self._exceptions.append(exc_val)
11✔
509

510
        while self.cancel_scope._tasks:
11✔
511
            try:
11✔
512
                await asyncio.wait(self.cancel_scope._tasks)
11✔
513
            except asyncio.CancelledError:
10✔
514
                self.cancel_scope.cancel()
10✔
515

516
        self._active = False
11✔
517
        if self._exceptions:
11✔
518
            exc: BaseException | None
519
            group = BaseExceptionGroup("multiple tasks failed", self._exceptions)
11✔
520
            if not self.cancel_scope._parent_cancelled():
11✔
521
                # If any exceptions other than AnyIO cancellation exceptions have been
522
                # received, raise those
523
                _, exc = group.split(is_anyio_cancelled_exc)
11✔
524
            elif all(is_anyio_cancelled_exc(e) for e in walk_exception_group(group)):
10✔
525
                # All tasks were cancelled by AnyIO
526
                exc = CancelledError()
10✔
527
            else:
528
                exc = group
10✔
529

530
            if isinstance(exc, BaseExceptionGroup):
11✔
531
                exc = collapse_exception_group(exc)
11✔
532

533
            if exc is not None and exc is not exc_val:
11✔
534
                raise exc
11✔
535

536
        return ignore_exception
11✔
537

538
    async def _run_wrapped_task(
        self, coro: Coroutine, task_status_future: asyncio.Future | None
    ) -> None:
        """
        Await the given coroutine and route its outcome to the task group.

        * An exception (of any kind) is set on ``task_status_future`` if that future
          exists and is not done yet; otherwise it is added to the task group's
          exception list and the group's cancel scope is cancelled.
        * If the coroutine returns normally while ``task_status_future`` is still
          pending, a :exc:`RuntimeError` is set on the future.
        * In every case, the current task is finally removed from the cancel scope's
          task set and from ``_task_states``, if it is still in the task set.

        :param coro: the coroutine to run
        :param task_status_future: the future associated with a task status object,
            or ``None`` if there is none

        """
        # This is the code path for Python 3.7 on which asyncio freaks out if a task
        # raises a BaseException.
        __traceback_hide__ = __tracebackhide__ = True  # noqa: F841
        this_task = cast(asyncio.Task, current_task())
        try:
            await coro
        except BaseException as exc:
            if task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(exc)
            else:
                self._exceptions.append(exc)
                self.cancel_scope.cancel()
        else:
            if task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )
        finally:
            scope_tasks = self.cancel_scope._tasks
            if this_task in scope_tasks:
                scope_tasks.remove(this_task)
                del _task_states[this_task]
3✔
562

563
    def _spawn(
11✔
564
        self,
565
        func: Callable[..., Awaitable[Any]],
566
        args: tuple,
567
        name: object,
568
        task_status_future: asyncio.Future | None = None,
569
    ) -> asyncio.Task:
570
        def task_done(_task: asyncio.Task) -> None:
11✔
571
            # This is the code path for Python 3.8+
572
            assert _task in self.cancel_scope._tasks
8✔
573
            self.cancel_scope._tasks.remove(_task)
8✔
574
            del _task_states[_task]
8✔
575

576
            try:
8✔
577
                exc = _task.exception()
8✔
578
            except CancelledError as e:
8✔
579
                while isinstance(e.__context__, CancelledError):
8✔
580
                    e = e.__context__
7✔
581

582
                exc = e
8✔
583

584
            if exc is not None:
8✔
585
                if task_status_future is None or task_status_future.done():
8✔
586
                    self._exceptions.append(exc)
8✔
587
                    self.cancel_scope.cancel()
8✔
588
                else:
589
                    task_status_future.set_exception(exc)
7✔
590
            elif task_status_future is not None and not task_status_future.done():
8✔
591
                task_status_future.set_exception(
7✔
592
                    RuntimeError("Child exited without calling task_status.started()")
593
                )
594

595
        if not self._active:
11✔
596
            raise RuntimeError(
11✔
597
                "This task group is not active; no new tasks can be started."
598
            )
599

600
        options: dict[str, Any] = {}
11✔
601
        name = get_callable_name(func) if name is None else str(name)
11✔
602
        if _native_task_names:
11✔
603
            options["name"] = name
8✔
604

605
        kwargs = {}
11✔
606
        if task_status_future:
11✔
607
            parent_id = id(current_task())
11✔
608
            kwargs["task_status"] = _AsyncioTaskStatus(
11✔
609
                task_status_future, id(self.cancel_scope._host_task)
610
            )
611
        else:
612
            parent_id = id(self.cancel_scope._host_task)
11✔
613

614
        coro = func(*args, **kwargs)
11✔
615
        if not asyncio.iscoroutine(coro):
11✔
616
            raise TypeError(
10✔
617
                f"Expected an async function, but {func} appears to be synchronous"
618
            )
619

620
        foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
11✔
621
        if foreign_coro or sys.version_info < (3, 8):
11✔
622
            coro = self._run_wrapped_task(coro, task_status_future)
3✔
623

624
        task = create_task(coro, **options)
11✔
625
        if not foreign_coro and sys.version_info >= (3, 8):
11✔
626
            task.add_done_callback(task_done)
8✔
627

628
        # Make the spawned task inherit the task group's cancel scope
629
        _task_states[task] = TaskState(
11✔
630
            parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
631
        )
632
        self.cancel_scope._tasks.add(task)
11✔
633
        return task
11✔
634

635
    def start_soon(
11✔
636
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
637
    ) -> None:
638
        self._spawn(func, args, name)
11✔
639

640
    async def start(
11✔
641
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
642
    ) -> None:
643
        future: asyncio.Future = asyncio.Future()
11✔
644
        task = self._spawn(func, args, name, future)
11✔
645

646
        # If the task raises an exception after sending a start value without a switch
647
        # point between, the task group is cancelled and this method never proceeds to
648
        # process the completed future. That's why we have to have a shielded cancel
649
        # scope here.
650
        with CancelScope(shield=True):
11✔
651
            try:
11✔
652
                return await future
11✔
653
            except CancelledError:
10✔
654
                task.cancel()
10✔
655
                raise
10✔
656

657

658
#
659
# Threads
660
#
661

662
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
11✔
663

664

665
class WorkerThread(Thread):
11✔
666
    MAX_IDLE_TIME = 10  # seconds
11✔
667

668
    def __init__(
11✔
669
        self,
670
        root_task: asyncio.Task,
671
        workers: set[WorkerThread],
672
        idle_workers: Deque[WorkerThread],
673
    ):
674
        super().__init__(name="AnyIO worker thread")
11✔
675
        self.root_task = root_task
11✔
676
        self.workers = workers
11✔
677
        self.idle_workers = idle_workers
11✔
678
        self.loop = root_task._loop
11✔
679
        self.queue: Queue[
11✔
680
            tuple[Context, Callable, tuple, asyncio.Future] | None
681
        ] = Queue(2)
682
        self.idle_since = AsyncIOBackend.current_time()
11✔
683
        self.stopping = False
11✔
684

685
    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Hand the outcome of a finished job over to its future.

        The worker's idle timestamp is refreshed and, unless the worker is stopping,
        the worker is appended to the idle worker collection. If the future has been
        cancelled, the outcome is dropped.

        :param future: the future waiting for the job's outcome
        :param result: the job's return value (used when ``exc`` is ``None``)
        :param exc: the exception raised by the job, if any; a
            :exc:`StopIteration` is replaced with a :exc:`RuntimeError` that has the
            original exception as its cause

        """
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)
11✔
702

703
    def run(self) -> None:
11✔
704
        with claim_worker_thread(AsyncIOBackend, self.loop):
11✔
705
            while True:
8✔
706
                item = self.queue.get()
11✔
707
                if item is None:
11✔
708
                    # Shutdown command received
709
                    return
11✔
710

711
                context, func, args, future = item
11✔
712
                if not future.cancelled():
11✔
713
                    result = None
11✔
714
                    exception: BaseException | None = None
11✔
715
                    try:
11✔
716
                        result = context.run(func, *args)
11✔
717
                    except BaseException as exc:
11✔
718
                        exception = exc
11✔
719

720
                    if not self.loop.is_closed():
11✔
721
                        self.loop.call_soon_threadsafe(
11✔
722
                            self._report_result, future, result, exception
723
                        )
724

725
                self.queue.task_done()
11✔
726

727
    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Tell the worker thread to shut down and forget about it.

        The worker is flagged as stopping, ``None`` (the shutdown command read by
        :meth:`run`) is put on its queue, and the worker is removed from both the
        worker set and the idle worker collection if present there.

        :param f: not used by this method

        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        # The worker may or may not be idle at this point
        if self in self.idle_workers:
            self.idle_workers.remove(self)
10✔
735

736

737
_threadpool_idle_workers: RunVar[Deque[WorkerThread]] = RunVar(
11✔
738
    "_threadpool_idle_workers"
739
)
740
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
11✔
741

742

743
class BlockingPortal(abc.BlockingPortal):
11✔
744
    def __new__(cls) -> BlockingPortal:
11✔
745
        return object.__new__(cls)
11✔
746

747
    def __init__(self) -> None:
11✔
748
        super().__init__()
11✔
749
        self._loop = get_running_loop()
11✔
750

751
    def _spawn_task_from_thread(
11✔
752
        self,
753
        func: Callable,
754
        args: tuple[Any, ...],
755
        kwargs: dict[str, Any],
756
        name: object,
757
        future: Future,
758
    ) -> None:
759
        AsyncIOBackend.run_sync_from_thread(
11✔
760
            partial(self._task_group.start_soon, name=name),
761
            (self._call_func, func, args, kwargs, future),
762
            self._loop,
763
        )
764

765

766
#
767
# Subprocesses
768
#
769

770

771
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Byte receive stream that reads from an :class:`asyncio.StreamReader`."""

    # The wrapped asyncio stream reader
    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to ``max_bytes`` bytes from the wrapped reader.

        :param max_bytes: upper limit passed to the reader's ``read()``
        :return: the bytes read
        :raises EndOfStream: if the read yielded no data

        """
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        """Feed EOF to the wrapped reader."""
        self._stream.feed_eof()
9✔
784

785

786
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Byte send stream that writes to an :class:`asyncio.StreamWriter`."""

    # The wrapped asyncio stream writer
    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """
        Write the given bytes to the wrapped writer and wait for it to drain.

        :param item: the bytes to write

        """
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        """Close the wrapped writer (without waiting for the close to complete)."""
        self._stream.close()
9✔
796

797

798
@dataclass(eq=False)
class Process(abc.Process):
    """
    Process object delegating to an :class:`asyncio.subprocess.Process`.

    The standard streams are exposed through the wrapper objects given at
    construction time; each of them may be ``None``.
    """

    # The wrapped asyncio subprocess
    _process: asyncio.subprocess.Process
    # Wrappers around the subprocess' standard streams (None where absent)
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close stdin, stdout and stderr (in that order), then wait for the process."""
        for stream in (self._stdin, self._stdout, self._stderr):
            if stream:
                await stream.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the wrapped process and return what its ``wait()`` returns."""
        return await self._process.wait()

    def terminate(self) -> None:
        """Call ``terminate()`` on the wrapped process."""
        self._process.terminate()

    def kill(self) -> None:
        """Call ``kill()`` on the wrapped process."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """
        Pass the given signal to ``send_signal()`` of the wrapped process.

        :param signal: the signal number

        """
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        """The ``pid`` attribute of the wrapped process."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The ``returncode`` attribute of the wrapped process."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        """The stream wrapping the process' standard input, or ``None``."""
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping the process' standard output, or ``None``."""
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping the process' standard error, or ``None``."""
        return self._stderr
9✔
846

847

848
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down worker processes without awaiting anything.

    :param workers: the worker processes to clean up
    :param _task: unused; NOTE(review): presumably present so that this function
        can be registered as a task done-callback - confirm against the caller

    """
    child_watcher: asyncio.AbstractChildWatcher | None = None
    # A child watcher is only looked up on Python versions older than 3.12
    if sys.version_info < (3, 12):
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            # The policy provides no child watcher; nothing to unregister later
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        # NOTE(review): this skips the processes that have no return code yet and
        # goes on to close/kill the ones that already have one, which looks
        # inverted - confirm the intended condition before relying on it.
        if process.returncode is None:
            continue

        process._stdin._stream._transport.close()  # type: ignore[union-attr]
        process._stdout._stream._transport.close()  # type: ignore[union-attr]
        process._stderr._stream._transport.close()  # type: ignore[union-attr]
        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
×
871

872

873
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
11✔
874
    """
875
    Shuts down worker processes belonging to this event loop.
876

877
    NOTE: this only works when the event loop was started using asyncio.run() or
878
    anyio.run().
879

880
    """
881
    process: abc.Process
882
    try:
9✔
883
        await sleep(math.inf)
9✔
884
    except asyncio.CancelledError:
9✔
885
        for process in workers:
9✔
886
            if process.returncode is None:
9✔
887
                process.kill()
9✔
888

889
        for process in workers:
9✔
890
            await process.aclose()
9✔
891

892

893
#
894
# Sockets and networking
895
#
896

897

898
class StreamProtocol(asyncio.Protocol):
    """asyncio protocol that queues received chunks and tracks I/O readiness."""

    read_queue: Deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the queue and events, then zero the write buffer limits."""
        stream_transport = cast(asyncio.Transport, transport)
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        # Writing starts out as allowed
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable
        stream_transport.set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        """Record the failure (if any) and wake up both readers and writers."""
        if exc:
            error = BrokenResourceError()
            error.__cause__ = exc
            self.exception = error

        for event in (self.read_event, self.write_event):
            event.set()

    def data_received(self, data: bytes) -> None:
        """Queue the received chunk and flag that data is available."""
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        """Flag the read event on EOF and return ``True``."""
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        """Swap in a fresh, unset write event."""
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        """Set the current write event."""
        self.write_event.set()
×
932

933

934
class DatagramProtocol(asyncio.DatagramProtocol):
    """asyncio datagram protocol that queues received packets."""

    read_queue: Deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the bounded packet queue and the readiness events."""
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        # Writing starts out as allowed
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        """Wake up both readers and writers; ``exc`` is not recorded."""
        for event in (self.read_event, self.write_event):
            event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        """Queue the packet together with its converted source address."""
        packet = (data, convert_ipv6_sockaddr(addr))
        self.read_queue.append(packet)
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        """Remember the reported error."""
        self.exception = exc

    def pause_writing(self) -> None:
        """Clear the write event."""
        self.write_event.clear()

    def resume_writing(self) -> None:
        """Set the write event."""
        self.write_event.set()
×
963

964

965
class SocketStream(abc.SocketStream):
    """Socket stream built on an asyncio transport and a ``StreamProtocol``."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        # NOTE(review): the guards presumably reject concurrent receive()/send()
        # calls - confirm against ResourceGuard
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set to True only by aclose() on this object
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` extra info of the underlying transport."""
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk from the protocol's read queue.

        :param max_bytes: maximum number of bytes to return; the remainder of a
            larger chunk is pushed back to the front of the queue
        :raises ClosedResourceError: if the queue is empty and this stream was
            closed via :meth:`aclose`
        :raises EndOfStream: if the queue is empty and the protocol recorded no
            exception
        :raises Exception: the exception recorded by the protocol, if the queue
            is empty and one was recorded

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Reading is only resumed for the duration of the wait
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
            ):
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                # Woken up (or never waited) without any queued data
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait for the protocol's write event.

        :raises ClosedResourceError: if this stream was closed via :meth:`aclose`
        :raises BrokenResourceError: if the write raises ``RuntimeError`` while
            the transport is closing
        :raises Exception: the exception recorded by the protocol, if any

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                # Only translate the error when the transport is going away
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Call ``write_eof()`` on the transport, ignoring ``OSError``."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            # close(), yield to the event loop once, then abort()
            self._transport.close()
            await sleep(0)
            self._transport.abort()
11✔
1047

1048

1049
class _RawSocketMixin:
11✔
1050
    _receive_future: asyncio.Future | None = None
11✔
1051
    _send_future: asyncio.Future | None = None
11✔
1052
    _closing = False
11✔
1053

1054
    def __init__(self, raw_socket: socket.socket):
11✔
1055
        self.__raw_socket = raw_socket
8✔
1056
        self._receive_guard = ResourceGuard("reading from")
8✔
1057
        self._send_guard = ResourceGuard("writing to")
8✔
1058

1059
    @property
11✔
1060
    def _raw_socket(self) -> socket.socket:
8✔
1061
        return self.__raw_socket
8✔
1062

1063
    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
11✔
1064
        def callback(f: object) -> None:
8✔
1065
            del self._receive_future
8✔
1066
            loop.remove_reader(self.__raw_socket)
8✔
1067

1068
        f = self._receive_future = asyncio.Future()
8✔
1069
        loop.add_reader(self.__raw_socket, f.set_result, None)
8✔
1070
        f.add_done_callback(callback)
8✔
1071
        return f
8✔
1072

1073
    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
11✔
1074
        def callback(f: object) -> None:
8✔
1075
            del self._send_future
8✔
1076
            loop.remove_writer(self.__raw_socket)
8✔
1077

1078
        f = self._send_future = asyncio.Future()
8✔
1079
        loop.add_writer(self.__raw_socket, f.set_result, None)
8✔
1080
        f.add_done_callback(callback)
8✔
1081
        return f
8✔
1082

1083
    async def aclose(self) -> None:
11✔
1084
        if not self._closing:
8✔
1085
            self._closing = True
8✔
1086
            if self.__raw_socket.fileno() != -1:
8✔
1087
                self.__raw_socket.close()
8✔
1088

1089
            if self._receive_future:
8✔
1090
                self._receive_future.set_result(None)
8✔
1091
            if self._send_future:
8✔
1092
                self._send_future.set_result(None)
×
1093

1094

1095
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """
    Byte stream over a raw UNIX socket, with support for passing file descriptors.

    Every operation retries after ``BlockingIOError`` once the event loop reports
    the socket ready. Other ``OSError``s are turned into ``ClosedResourceError``
    (if this stream is closing) or ``BrokenResourceError``.
    """

    async def send_eof(self) -> None:
        """Shut down the writing side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive at most ``max_bytes`` bytes.

        :raises EndOfStream: if the socket returned no data
        :raises ClosedResourceError: on ``OSError`` while this stream is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item``, retrying until every byte has been handed over.

        :raises ClosedResourceError: on ``OSError`` while this stream is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            view = memoryview(item)
            while view:
                try:
                    # Send only the part that has not been transmitted yet;
                    # passing ``item`` here would resend the beginning of the
                    # payload after a partial send and corrupt the stream.
                    bytes_sent = self._raw_socket.send(view)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message together with file descriptors sent as ancillary data.

        :param msglen: maximum length of the message part
        :param maxfds: maximum number of file descriptors to accept
        :return: a tuple of (message, list of received file descriptors)
        :raises ValueError: if ``msglen`` is not a non-negative integer or
            ``maxfds`` is not a positive integer
        :raises EndOfStream: if neither message data nor ancillary data arrived
        :raises RuntimeError: if ancillary data other than ``SCM_RIGHTS`` arrived
        :raises ClosedResourceError: on ``OSError`` while this stream is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Drop any trailing bytes that do not make up a whole descriptor
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message along with the given file descriptors.

        :param message: a non-empty bytes object
        :param fds: a non-empty collection of file descriptors, given either as
            integers or as file objects; entries of any other type are ignored
        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: on ``OSError`` while this stream is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
×
1209

1210

1211
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections through the event loop."""

    # Cancel scope of the accept() call currently in progress, if any
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The listening socket passed to the initializer."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`SocketStream`.

        ``TCP_NODELAY`` is enabled on the accepted socket.

        :raises ClosedResourceError: if the listener is already closed, or is
            closed while the accept is in progress

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            # The scope is published on self so that aclose() can cancel it
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # A cancellation caused by aclose() is reported as a closed
                    # resource; any other cancellation propagates unchanged
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listening socket, cancelling an accept() in progress."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            # Cancel the pending accept() and yield once to the event loop
            # before the socket is closed
            self._accept_scope.cancel()
            await sleep(0)

        self._raw_socket.close()
11✔
1269

1270

1271
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections from a raw UNIX socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The listening socket passed to the initializer."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`UNIXSocketStream`.

        :raises ClosedResourceError: on ``OSError`` after the listener was closed
        :raises BrokenResourceError: on any other ``OSError``

        """

        def unregister_reader(_: object) -> None:
            self._loop.remove_reader(self.__raw_socket)

        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    # Nothing to accept yet: wait until the loop reports the
                    # listening socket as readable, then try again
                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(
                        self.__raw_socket, readable.set_result, None
                    )
                    readable.add_done_callback(unregister_reader)
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener as closed and close the listening socket."""
        self._closed = True
        self.__raw_socket.close()
8✔
1306

1307

1308
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket built on an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` extra info of the underlying transport."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next queued ``(data, address)`` packet.

        :raises ClosedResourceError: if no packet is queued and this socket was
            closed via :meth:`aclose`
        :raises BrokenResourceError: if no packet is queued otherwise

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol

            # Nothing buffered and the transport is still open: wait for data
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a packet, given as the positional arguments for ``sendto()``.

        :raises ClosedResourceError: if this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing otherwise

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
10✔
1354

1355

1356
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket built on an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The ``"socket"`` extra info of the underlying transport."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next queued packet.

        :raises ClosedResourceError: if no packet is queued and this socket was
            closed via :meth:`aclose`
        :raises BrokenResourceError: if no packet is queued otherwise

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol

            # Nothing buffered and the transport is still open: wait for data
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                payload, _address = protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None

            return payload

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` through the transport.

        :raises ClosedResourceError: if this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing otherwise

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
10✔
1404

1405

1406
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected datagram socket over a raw UNIX socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Return the result of ``recvfrom(65536)`` once a datagram is available.

        :raises ClosedResourceError: on ``OSError`` while this socket is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    # Not ready yet; retry after the socket becomes readable
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send a packet, given as the positional arguments for ``sendto()``.

        :raises ClosedResourceError: on ``OSError`` while this socket is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    # Not ready yet; retry after the socket becomes writable
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
8✔
1440

1441

1442
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected datagram socket over a raw UNIX socket."""

    async def receive(self) -> bytes:
        """
        Return the result of ``recv(65536)`` once a datagram is available.

        :raises ClosedResourceError: on ``OSError`` while this socket is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    # Not ready yet; retry after the socket becomes readable
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` with a single ``send()`` call.

        :raises ClosedResourceError: on ``OSError`` while this socket is closing
        :raises BrokenResourceError: on any other ``OSError``

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    # Not ready yet; retry after the socket becomes writable
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
8✔
1476

1477

1478
# Run-scoped mappings to asyncio events.
# NOTE(review): presumably keyed by socket/file descriptor and used by the
# readability/writability waiting helpers elsewhere in this module - confirm.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
11✔
1480

1481

1482
#
1483
# Synchronization
1484
#
1485

1486

1487
class Event(BaseEvent):
    """Event implementation that delegates to an ``asyncio.Event``."""

    def __new__(cls) -> Event:
        # Allocate through ``object`` directly instead of the base class
        instance = object.__new__(cls)
        return instance

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the underlying event."""
        self._event.set()

    def is_set(self) -> bool:
        """Return whether the underlying event is set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Wait for the underlying event, then pass through a checkpoint."""
        woken = await self._event.wait()
        if woken:
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics holding the number of tasks waiting on the event."""
        waiters = self._event._waiters  # type: ignore[attr-defined]
        return EventStatistics(len(waiters))
10✔
1506

1507

1508
class CapacityLimiter(BaseCapacityLimiter):
    """
    Limits how many borrowers may hold a token at the same time.

    Borrowers that cannot get a token immediately are queued in FIFO order in
    ``_wait_queue`` and woken through their own ``asyncio.Event``.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        # Goes through the property setter for validation
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the token total and wake one waiter per newly added token.

        :raises TypeError: if ``value`` is neither an int nor ``math.inf``
        :raises ValueError: if ``value`` is less than 1

        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        old_value = self._total_tokens
        self._total_tokens = value
        waiters_to_notify = value - old_value if value > old_value else 0

        # Remove each notified waiter from the queue, exactly like
        # release_on_behalf_of() does. Leaving the entry behind would make a
        # later release pop that stale, already-set event instead of waking the
        # next real waiter, and would keep acquire_*_nowait() raising WouldBlock.
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently borrowed."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens currently not borrowed."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """Acquire a token for the current task without waiting."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower`` without waiting.

        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if others are already waiting or no token is free

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Acquire a token for ``borrower``, waiting in line if necessary."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Leave the queue if the wait was interrupted
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Give the token back if the checkpoint raised
                self.release()
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by ``borrower``.

        :raises RuntimeError: if ``borrower`` does not hold a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return a snapshot of the borrowed/total tokens, borrowers and waiters."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1625

1626

1627
# Run-scoped CapacityLimiter.
# NOTE(review): presumably the default limiter for worker threads, going by the
# name - confirm where it is set and read.
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
11✔
1628

1629

1630
#
1631
# Operating system signals
1632
#
1633

1634

1635
class _SignalReceiver:
11✔
1636
    def __init__(self, signals: tuple[Signals, ...]):
11✔
1637
        self._signals = signals
9✔
1638
        self._loop = get_running_loop()
9✔
1639
        self._signal_queue: Deque[Signals] = deque()
9✔
1640
        self._future: asyncio.Future = asyncio.Future()
9✔
1641
        self._handled_signals: set[Signals] = set()
9✔
1642

1643
    def _deliver(self, signum: Signals) -> None:
11✔
1644
        self._signal_queue.append(signum)
9✔
1645
        if not self._future.done():
9✔
1646
            self._future.set_result(None)
9✔
1647

1648
    def __enter__(self) -> _SignalReceiver:
11✔
1649
        for sig in set(self._signals):
9✔
1650
            self._loop.add_signal_handler(sig, self._deliver, sig)
9✔
1651
            self._handled_signals.add(sig)
9✔
1652

1653
        return self
9✔
1654

1655
    def __exit__(
11✔
1656
        self,
1657
        exc_type: type[BaseException] | None,
1658
        exc_val: BaseException | None,
1659
        exc_tb: TracebackType | None,
1660
    ) -> bool | None:
1661
        for sig in self._handled_signals:
9✔
1662
            self._loop.remove_signal_handler(sig)
9✔
1663
        return None
9✔
1664

1665
    def __aiter__(self) -> _SignalReceiver:
11✔
1666
        return self
9✔
1667

1668
    async def __anext__(self) -> Signals:
11✔
1669
        await AsyncIOBackend.checkpoint()
9✔
1670
        if not self._signal_queue:
9✔
1671
            self._future = asyncio.Future()
×
1672
            await self._future
×
1673

1674
        return self._signal_queue.popleft()
9✔
1675

1676

1677
#
1678
# Testing and debugging
1679
#
1680

1681

1682
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a :class:`TaskInfo` describing ``task``.

    The name and parent id come from the task's entry in ``_task_states`` when
    there is one. Otherwise there is no parent id, and the name is the native
    task name (or ``None`` if native task names are unavailable).

    """
    state = _task_states.get(task)
    if state is not None:
        name, parent_id = state.name, state.parent_id
    elif _native_task_names:
        name, parent_id = task.get_name(), None
    else:
        name = parent_id = None

    return TaskInfo(id(task), parent_id, name, get_coro(task))
1692

1693

1694
async def _shutdown_default_executor(loop: asyncio.BaseEventLoop) -> None:
11✔
1695
    """Schedule the shutdown of the default executor.
1696
    BaseEventLoop.shutdown_default_executor was introduced in Python 3.9.
1697
    This function is an adapted version of the method from Python 3.11.
1698
    It's used in TestRunner.close only if python < 3.9.
1699
    """
1700

1701
    def _do_shutdown(
4✔
1702
        loop_: asyncio.BaseEventLoop, future: asyncio.futures.Future
1703
    ) -> None:
1704
        try:
4✔
1705
            loop_._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
4✔
1706
            loop_.call_soon_threadsafe(future.set_result, None)
4✔
1707
        except Exception as ex:
×
1708
            loop_.call_soon_threadsafe(future.set_exception, ex)
×
1709

1710
    if loop._default_executor is None:  # type: ignore[attr-defined]
4✔
1711
        return
4✔
1712
    future = loop.create_future()
4✔
1713
    thread = threading.Thread(
4✔
1714
        target=_do_shutdown,
1715
        args=(
1716
            loop,
1717
            future,
1718
        ),
1719
    )
1720
    thread.start()
4✔
1721
    try:
4✔
1722
        await future
4✔
1723
    finally:
1724
        thread.join()
4✔
1725

1726

1727
class TestRunner(abc.TestRunner):
    """
    Runs async tests and fixtures on a private asyncio event loop.

    All test and fixture coroutines are executed inside a single long-lived
    runner task, which is fed through a memory object stream. Exceptions
    reported to the loop's exception handler are collected and re-raised after
    each test or fixture call.

    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        debug: bool = False,
        use_uvloop: bool = False,
        policy: asyncio.AbstractEventLoopPolicy | None = None,
    ):
        # Exceptions collected by _exception_handler() and run_test()
        self._exceptions: list[BaseException] = []
        _maybe_set_event_loop_policy(policy, use_uvloop)
        self._loop = loop = asyncio.new_event_loop()
        loop.set_debug(debug)
        loop.set_exception_handler(self._exception_handler)
        # Created lazily by _call_in_runner_task()
        self._runner_task: asyncio.Task | None = None
        asyncio.set_event_loop(loop)

    def _cancel_all_tasks(self) -> None:
        """Cancel every task of the loop, wait for them and surface any failure."""
        remaining = all_tasks(self._loop)
        if not remaining:
            return

        for task in remaining:
            task.cancel()

        waiter = asyncio.gather(*remaining, return_exceptions=True)
        self._loop.run_until_complete(waiter)

        for task in remaining:
            if not task.cancelled() and task.exception() is not None:
                raise cast(BaseException, task.exception())

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances; defer everything else to the default."""
        exc = context.get("exception")
        if isinstance(exc, Exception):
            self._exceptions.append(context["exception"])
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        """Re-raise (and clear) any exceptions raised in asynchronous callbacks."""
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Coroutine[Any, Any, T_Retval], Future[T_Retval]]
        ],
    ) -> None:
        """Await each received coroutine and relay its outcome to its future."""
        with receive_stream:
            async for job, outcome in receive_stream:
                try:
                    value = await job
                except BaseException as exc:
                    if not outcome.cancelled():
                        outcome.set_exception(exc)
                else:
                    if not outcome.cancelled():
                        outcome.set_result(value)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """Call ``func`` in the runner task (starting it if needed); return its result."""
        if not self._runner_task:
            self._send_stream, receiver = create_memory_object_stream(1)
            runner = self._run_tests_and_fixtures(receiver)
            self._runner_task = self._loop.create_task(runner)

        job = func(*args, **kwargs)
        outcome: asyncio.Future[T_Retval] = self._loop.create_future()
        self._send_stream.send_nowait((job, outcome))
        return await outcome

    def close(self) -> None:
        """Stop the runner task, cancel leftover tasks and close the event loop."""
        loop = self._loop
        try:
            if self._runner_task is not None:
                self._runner_task = None
                loop.run_until_complete(self._send_stream.aclose())
                del self._send_stream

            self._cancel_all_tasks()
            loop.run_until_complete(loop.shutdown_asyncgens())
            if hasattr(loop, "shutdown_default_executor"):
                # asyncio in Python >= 3.9 or uvloop >= 0.15.0
                loop.run_until_complete(loop.shutdown_default_executor())
            elif isinstance(loop, asyncio.BaseEventLoop) and hasattr(
                loop, "_default_executor"
            ):
                # asyncio in Python < 3.9
                loop.run_until_complete(_shutdown_default_executor(loop))
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Drive an async generator fixture: yield its value, then finalize it.

        :raises RuntimeError: if the generator yields a second time

        """
        agen = fixture_func(**kwargs)
        setup = self._call_in_runner_task(agen.asend, None)
        fixturevalue: T_Retval = self._loop.run_until_complete(setup)
        self._raise_async_exceptions()

        yield fixturevalue

        # Resume the generator; it is expected to finish rather than yield again
        teardown = self._call_in_runner_task(agen.asend, None)
        try:
            self._loop.run_until_complete(teardown)
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self._loop.run_until_complete(agen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        call = self._call_in_runner_task(fixture_func, **kwargs)
        result = self._loop.run_until_complete(call)
        self._raise_async_exceptions()
        return result

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a test in the runner task, raising anything that went wrong."""
        call = self._call_in_runner_task(test_func, **kwargs)
        try:
            self._loop.run_until_complete(call)
        except Exception as exc:
            # Reported together with exceptions from asynchronous callbacks
            self._exceptions.append(exc)

        self._raise_async_exceptions()
1878

1879

1880
class AsyncIOBackend(AsyncBackend):
    """Implementation of the ``AsyncBackend`` interface on top of asyncio."""

    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func`` to completion in a new event loop via :func:`asyncio.run`.

        :param func: the coroutine function to call
        :param args: positional arguments to call ``func`` with
        :param kwargs: keyword arguments to call ``func`` with
        :param options: backend options; the ``debug``, ``policy`` and
            ``use_uvloop`` keys are read
        :return: the return value of ``func``

        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            # Register the root task so task state lookups work for it
            task = cast(asyncio.Task, current_task())
            task_state = TaskState(None, get_callable_name(func), None)
            _task_states[task] = task_state
            if _native_task_names:
                task.set_name(task_state.name)

            try:
                return await func(*args, **kwargs)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        policy = options.get("policy", None)
        use_uvloop = options.get("use_uvloop", False)
        _maybe_set_event_loop_policy(policy, use_uvloop)
        return native_run(wrapper(), debug=debug)

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, which serves as the backend token."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's current time."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class used for cancellation (``CancelledError``)."""
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """Yield to the event loop only if an effective cancel scope was cancelled."""
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        # Walk outwards through the enclosing scopes, stopping at a shield
        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield control to the event loop once, inside a shielded cancel scope."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Sleep for ``delay`` seconds."""
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Create a new cancel scope with the given deadline and shield setting."""
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline among the current task's effective cancel scopes.

        :return: ``math.inf`` if the task has no registered state or no
            deadline applies, ``-math.inf`` if an effective scope has already
            been cancelled, otherwise the smallest deadline found

        """
        try:
            cancel_scope = _task_states[
                current_task()  # type: ignore[index]
            ].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                # Scopes outside a shield do not affect this task
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group."""
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event."""
        return Event()

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Create a capacity limiter with ``total_tokens`` tokens."""
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        cancellable: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in a worker thread and return its result.

        :param func: the callable to run in the worker thread
        :param args: positional arguments to call ``func`` with
        :param cancellable: if false, the wait for the result is shielded from
            cancellation
        :param limiter: capacity limiter to use instead of the default thread
            limiter
        :return: the result delivered through the worker's future

        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with (limiter or cls.current_default_thread_limiter()):
            with CancelScope(shield=not cancellable):
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    # No idle worker available: start a new one that is
                    # stopped when the root task finishes
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # The callable runs in a copy of the current context in which
                # sniffio's current async library is reset to None
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                worker.queue.put_nowait((context, func, args, future))
                return await future

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` on the event loop given as ``token`` and wait for it.

        :param func: the coroutine function to call
        :param args: positional arguments to call ``func`` with
        :param token: the event loop to run the coroutine on
        :return: the return value of the coroutine

        """
        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, func(*args), loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` in the event loop thread and wait for the result.

        :param func: the callable to run in the event loop thread
        :param args: positional arguments to call ``func`` with
        :param token: the event loop to schedule the call on
        :return: the return value of ``func``

        """

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Non-Exception errors (e.g. KeyboardInterrupt) also propagate
                # in the event loop thread
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal."""
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """
        Start a subprocess and wrap it and its standard streams.

        :param command: the command line if ``shell`` is true, otherwise the
            sequence of program arguments
        :param shell: whether to run the command through the shell
        :param stdin: passed through to asyncio's subprocess creation
        :param stdout: passed through to asyncio's subprocess creation
        :param stderr: passed through to asyncio's subprocess creation
        :param cwd: working directory for the subprocess
        :param env: environment for the subprocess
        :param start_new_session: passed through to asyncio's subprocess creation
        :return: the wrapped process

        """
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """Arrange for the given process pool workers to be shut down on exit."""
        kwargs: dict[str, Any] = (
            {"name": "AnyIO process pool shutdown task"} if _native_task_names else {}
        )
        create_task(_shutdown_process_pool_on_exit(workers), **kwargs)
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """
        Open a TCP connection and wrap it in a socket stream.

        :param host: host to connect to
        :param port: port to connect to
        :param local_address: optional local address to bind the socket to
        :return: the connected stream

        """
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @staticmethod
    async def _connect_unix_socket(raw_socket: socket.socket, path: str) -> None:
        """
        Connect the non-blocking ``raw_socket`` to ``path``.

        The socket is closed if connecting fails for any reason, including
        cancellation while waiting for the socket to become writable.

        """
        loop = get_running_loop()
        try:
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    # Retry once the event loop reports the socket as writable
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return
        except BaseException:
            raw_socket.close()
            raise

    @classmethod
    async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
        """
        Connect to the UNIX socket at ``path`` and wrap it in a stream.

        :param path: filesystem path of the socket to connect to
        :return: the connected stream

        """
        await cls.checkpoint()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        await cls._connect_unix_socket(raw_socket, path)
        return UNIXSocketStream(raw_socket)

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a TCP socket listener."""
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a UNIX socket listener."""
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a UDP socket through the event loop's datagram endpoint API.

        :param family: address family of the socket
        :param local_address: local address to bind to, if any
        :param remote_address: remote address to connect to, if any
        :param reuse_port: passed through as the endpoint's ``reuse_port`` option
        :return: a connected UDP socket if ``remote_address`` was given,
            otherwise an unconnected one

        """
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap ``raw_socket`` as a UNIX datagram socket, connecting it if requested.

        :param raw_socket: the socket to wrap
        :param remote_path: path to connect the socket to, if any
        :return: a connected UNIX datagram socket if ``remote_path`` was given,
            otherwise an unconnected one

        """
        await cls.checkpoint()
        if remote_path:
            await cls._connect_unix_socket(raw_socket, remote_path)
            return ConnectedUNIXDatagramSocket(raw_socket)
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host`` and ``port`` using the event loop's ``getaddrinfo()``."""
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Look up ``sockaddr`` using the event loop's ``getnameinfo()``."""
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` is reported as readable.

        :raises BusyResourceError: if a wait for readability is already
            registered for this socket
        :raises ClosedResourceError: if the registration was removed from the
            read events mapping while waiting

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            # If our entry is gone, something else already cleaned it up
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` is reported as writable.

        :raises BusyResourceError: if a wait for writability is already
            registered for this socket
        :raises ClosedResourceError: if the registration was removed from the
            write events mapping while waiting

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register the same object that remove_writer() is later called with
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            # If our entry is gone, something else already cleaned it up
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """Return the default thread limiter, creating a 40-token one if unset."""
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """Create a receiver for the given signals."""
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return task info for the currently running task."""
        return _create_task_info(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        """Return task info for every unfinished task of the running loop."""
        return [_create_task_info(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """Wait until every task other than the current one is blocked."""
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                # A task with no pending waiter is treated as not blocked yet;
                # sleep briefly and rescan all tasks
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """Create a test runner, passing ``options`` as keyword arguments."""
        return TestRunner(**options)
2348

2349

2350
# Module-level alias for this module's backend implementation.
# NOTE(review): presumably looked up by name by the backend loader - confirm.
backend_class = AsyncIOBackend
11✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc