• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 4656540756

pending completion
4656540756

push

github

Alex Grönholm
Run CI against the latest PyPy

3912 of 4340 relevant lines covered (90.14%)

8.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.84
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import socket
10✔
8
import sys
10✔
9
import threading
10✔
10
from asyncio import (
10✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio import run as native_run
10✔
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
21
from collections import OrderedDict, deque
10✔
22
from collections.abc import AsyncIterator, Iterable
10✔
23
from concurrent.futures import Future
10✔
24
from contextvars import Context, copy_context
10✔
25
from dataclasses import dataclass
10✔
26
from functools import partial, wraps
10✔
27
from inspect import CORO_RUNNING, CORO_SUSPENDED, getcoroutinestate
10✔
28
from io import IOBase
10✔
29
from os import PathLike
10✔
30
from queue import Queue
10✔
31
from signal import Signals
10✔
32
from socket import AddressFamily, SocketKind
10✔
33
from threading import Thread
10✔
34
from types import TracebackType
10✔
35
from typing import (
10✔
36
    IO,
37
    Any,
38
    AsyncGenerator,
39
    Awaitable,
40
    Callable,
41
    Collection,
42
    ContextManager,
43
    Coroutine,
44
    Deque,
45
    Generator,
46
    Iterator,
47
    Mapping,
48
    Optional,
49
    Sequence,
50
    Tuple,
51
    TypeVar,
52
    cast,
53
)
54
from weakref import WeakKeyDictionary
10✔
55

56
import sniffio
10✔
57

58
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
10✔
59
from .._core._eventloop import claim_worker_thread
10✔
60
from .._core._exceptions import (
10✔
61
    BrokenResourceError,
62
    BusyResourceError,
63
    ClosedResourceError,
64
    EndOfStream,
65
    WouldBlock,
66
)
67
from .._core._sockets import convert_ipv6_sockaddr
10✔
68
from .._core._streams import create_memory_object_stream
10✔
69
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
10✔
70
from .._core._synchronization import Event as BaseEvent
10✔
71
from .._core._synchronization import ResourceGuard
10✔
72
from .._core._tasks import CancelScope as BaseCancelScope
10✔
73
from ..abc import (
10✔
74
    AsyncBackend,
75
    IPSockAddrType,
76
    SocketListener,
77
    UDPPacketType,
78
    UNIXDatagramPacketType,
79
)
80
from ..lowlevel import RunVar
10✔
81
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
82

83
if sys.version_info < (3, 11):
10✔
84
    from exceptiongroup import BaseExceptionGroup, ExceptionGroup
7✔
85

86
def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
    """
    Return the coroutine (or other awaitable) wrapped by *task*.

    Python 3.8+ exposes this through the public ``Task.get_coro()`` method; on
    older interpreters only the private ``_coro`` attribute is available.
    """
    if sys.version_info >= (3, 8):
        return task.get_coro()

    return task._coro
95

96

97
T_Retval = TypeVar("T_Retval")

# asyncio only offers Task.get_name() natively on Python 3.8+; record whether this
# interpreter provides it
_native_task_names: bool = hasattr(asyncio.Task, "get_name")

# Root task remembered by find_root_task() once it has been located
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
103

104

105
def find_root_task() -> asyncio.Task:
    """
    Locate the task that should be treated as the root of the current task tree.

    Lookup order:

    1. the previously cached root task, if it is still running;
    2. a running task carrying ``run_until_complete()``'s done callback (or a
       callback from ``uvloop.loop``), which is then cached;
    3. the host task of the outermost AnyIO cancel scope of the current task;
    4. the current task itself.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if candidate.done() or not candidate._callbacks:
            continue

        started_by_run_until_complete = any(
            cb is _run_until_complete_cb
            or getattr(cb, "__module__", None) == "uvloop.loop"
            for cb, _context in candidate._callbacks
        )
        if started_by_run_until_complete:
            _root_task.set(candidate)
            return candidate

    # Fall back to the topmost task in the AnyIO task tree, if possible
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if not state:
        return current

    outermost = state.cancel_scope
    while outermost and outermost._parent_scope is not None:
        outermost = outermost._parent_scope

    if outermost is None:
        return current

    return cast(asyncio.Task, outermost._host_task)
134

135

136
def get_callable_name(func: Callable) -> str:
    """
    Build a dotted ``module.qualname`` name for *func*.

    Whichever of ``__module__`` / ``__qualname__`` is missing or empty is left out.
    """
    name_parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(part for part in name_parts if part)
140

141

142
#
143
# Event loop
144
#
145

146
_run_vars = (
10✔
147
    WeakKeyDictionary()
148
)  # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any]
149

150

151
def _task_started(task: asyncio.Task) -> bool:
    """
    Return ``True`` if the task has been started and has not finished.

    :raises Exception: if the object wrapped by the task is not a real coroutine
        whose state :func:`inspect.getcoroutinestate` can inspect
    """
    coro = cast(Coroutine[Any, Any, Any], get_coro(task))
    try:
        state = getcoroutinestate(coro)
    except AttributeError:
        # The task wraps something like async_generator_asend, which lacks the
        # coroutine attributes getcoroutinestate() reads
        # (https://bugs.python.org/issue37771)
        raise Exception(f"Cannot determine if task {task} has started or not")

    return state in (CORO_RUNNING, CORO_SUSPENDED)
159

160

161
def _maybe_set_event_loop_policy(
10✔
162
    policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool
163
) -> None:
164
    # On CPython, use uvloop when possible if no other policy has been given and if not
165
    # explicitly disabled
166
    if policy is None and use_uvloop and sys.implementation.name == "cpython":
10✔
167
        try:
×
168
            import uvloop
×
169
        except ImportError:
×
170
            pass
×
171
        else:
172
            # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier)
173
            if not hasattr(
×
174
                asyncio.AbstractEventLoop, "shutdown_default_executor"
175
            ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"):
176
                policy = uvloop.EventLoopPolicy()
×
177

178
    if policy is not None:
10✔
179
        asyncio.set_event_loop_policy(policy)
10✔
180

181

182
#
183
# Timeouts and cancellation
184
#
185

186

187
class CancelScope(BaseCancelScope):
    """
    asyncio implementation of an AnyIO cancel scope.

    A scope is entered in a host task and, while active, tracks that task plus any
    other tasks registered in ``_tasks`` (task groups add their children there).
    Cancelling the scope calls ``task.cancel()`` on every eligible task and keeps
    rescheduling itself with ``call_soon`` until no eligible task remains. A finite
    deadline schedules such a cancellation through ``loop.call_at``.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Skip the base class __new__ (presumably a backend dispatcher - confirm)
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._cancel_called = False
        self._active = False
        self._timeout_handle: asyncio.TimerHandle | None = None
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        self._timeout_expired = False

    def __enter__(self) -> CancelScope:
        """
        Activate the scope in the current task.

        :raises RuntimeError: if the scope has already been entered
        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # No AnyIO state for this task yet: this scope becomes its first scope
            task_name = host_task.get_name() if _native_task_names else None
            task_state = TaskState(None, task_name, self)
            _task_states[host_task] = task_state
        else:
            # Nest inside whatever scope the task is currently in
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Deactivate the scope and decide whether to swallow the exception.

        Returns ``True`` (swallow) when the exception consists solely of
        ``CancelledError``s and the cancellation came from this scope, either
        through its deadline or through a direct ``cancel()`` while no
        unshielded parent was cancelled. Otherwise returns ``None``.

        :raises RuntimeError: if the scope is not active, is exited from a task
            other than the one that entered it, or is not the task's current scope
        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent
        # scope if this one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        if exc_val is not None:
            # CancelledError is a BaseException (not an Exception) on Python 3.8+,
            # so a group holding only cancellations is a BaseExceptionGroup, never an
            # ExceptionGroup; checking the base class covers every group type
            exceptions = (
                exc_val.exceptions
                if isinstance(exc_val, BaseExceptionGroup)
                else [exc_val]
            )
            if all(isinstance(exc, CancelledError) for exc in exceptions):
                if self._timeout_expired:
                    return True
                elif not self._cancel_called:
                    # Task was cancelled natively
                    return None
                elif not self._parent_cancelled():
                    # This scope was directly cancelled
                    return True

        return None

    def _timeout(self) -> None:
        """Cancel now if the deadline has passed, else schedule a re-check at it."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self._timeout_expired = True
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            if task._must_cancel:  # type: ignore[attr-defined]
                # A cancellation request is already pending for this task
                continue

            # The task is eligible for cancellation if it has started and is not in a
            # cancel scope shielded from this one
            cancel_scope = _task_states[task].cancel_scope
            while cancel_scope is not self:
                if cancel_scope is None or cancel_scope._shield:
                    break
                else:
                    cancel_scope = cancel_scope._parent_scope
            else:
                # Reached this scope without crossing a shield
                should_retry = True
                if task is not current and (
                    task is self._host_task or _task_started(task)
                ):
                    task.cancel()

        # Schedule another callback if there are still tasks left
        if should_retry:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        scope = self._parent_scope
        scope_to_cancel: CancelScope | None = None
        while scope is not None:
            # Cancelled scopes whose delivery loop is no longer scheduled
            if scope._cancel_called and scope._cancel_handle is None:
                scope_to_cancel = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if scope_to_cancel is not None:
            scope_to_cancel._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """Return ``True`` if any parent up to the nearest shielded one was cancelled."""
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """Cancel the scope; idempotent. Delivery starts once a host task exists."""
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation()

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the timer against the new deadline while the scope is live
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            if not value:
                # Dropping the shield exposes this scope to pending parent cancellation
                self._deliver_cancellation_to_parent()
395

396

397
#
398
# Task states
399
#
400

401

402
class TaskState:
    """
    Side-car record of AnyIO bookkeeping for one asyncio task.

    It is stored separately (keyed by task) rather than on the Task object,
    because the Task implementation gives no guarantee that extra attributes
    can be added to it.
    """

    __slots__ = ("parent_id", "name", "cancel_scope")

    def __init__(
        self, parent_id: int | None, name: str | None, cancel_scope: CancelScope | None
    ):
        self.parent_id, self.name, self.cancel_scope = parent_id, name, cancel_scope
416

417

418
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
10✔
419

420

421
#
422
# Task groups
423
#
424

425

426
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status handed to children launched with ``TaskGroup.start()``.

    Calling :meth:`started` resolves the future the starter awaits. It also
    re-parents the calling task to the id stored in ``_parent_id``.
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._future = future
        self._parent_id = parent_id

    def started(self, value: object = None) -> None:
        """
        Report a successful start, passing *value* to the waiting starter.

        :raises RuntimeError: if called more than once
        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None

        this_task = cast(asyncio.Task, current_task())
        _task_states[this_task].parent_id = self._parent_id
441

442

443
def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
10✔
444
    exceptions = list(excgroup.exceptions)
10✔
445
    modified = False
10✔
446
    for i, exc in enumerate(exceptions):
10✔
447
        if isinstance(exc, BaseExceptionGroup):
10✔
448
            new_exc = collapse_exception_group(exc)
9✔
449
            if new_exc is not exc:
9✔
450
                modified = True
9✔
451
                exceptions[i] = new_exc
9✔
452

453
    if len(exceptions) == 1:
10✔
454
        return exceptions[0]
10✔
455
    elif modified:
9✔
456
        return excgroup.derive(exceptions)
9✔
457
    else:
458
        return excgroup
9✔
459

460

461
def walk_exception_group(excgroup: BaseExceptionGroup) -> Iterator[BaseException]:
10✔
462
    for exc in excgroup.exceptions:
9✔
463
        if isinstance(exc, BaseExceptionGroup):
9✔
464
            yield from walk_exception_group(exc)
×
465
        else:
466
            yield exc
9✔
467

468

469
def is_anyio_cancelled_exc(exc: BaseException) -> bool:
    """
    Tell whether *exc* looks like a cancellation delivered by AnyIO.

    AnyIO cancels with a bare ``task.cancel()``, so its ``CancelledError`` carries
    no arguments; a cancellation with a message is treated as a native one.
    """
    if not isinstance(exc, CancelledError):
        return False

    return not exc.args
471

472

473
class TaskGroup(abc.TaskGroup):
    """
    asyncio implementation of an AnyIO task group.

    Child tasks are registered in the group's own ``cancel_scope``, so cancelling
    that scope cancels every child. Exceptions raised by children (and by the body
    of the ``async with`` block) are collected in ``_exceptions`` and raised,
    collapsed into an exception group where needed, when the block exits.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for every child task to finish, then raise the collected exceptions.

        An exception from the block body cancels the remaining children and is
        collected together with theirs.
        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            self.cancel_scope.cancel()
            self._exceptions.append(exc_val)

        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except asyncio.CancelledError:
                # The host was cancelled while waiting: cancel the children as well
                # and keep waiting until they have all finished
                self.cancel_scope.cancel()

        self._active = False
        if self._exceptions:
            exc: BaseException | None
            group = BaseExceptionGroup("multiple tasks failed", self._exceptions)
            if not self.cancel_scope._parent_cancelled():
                # If any exceptions other than AnyIO cancellation exceptions have been
                # received, raise those
                _, exc = group.split(is_anyio_cancelled_exc)
            elif all(is_anyio_cancelled_exc(e) for e in walk_exception_group(group)):
                # All tasks were cancelled by AnyIO
                exc = CancelledError()
            else:
                exc = group

            if isinstance(exc, BaseExceptionGroup):
                exc = collapse_exception_group(exc)

            # If what remains is just the body's own exception, fall through and let
            # the return value decide whether it propagates
            if exc is not None and exc is not exc_val:
                raise exc

        return ignore_exception

    async def _run_wrapped_task(
        self, coro: Coroutine, task_status_future: asyncio.Future | None
    ) -> None:
        """
        Run *coro* and record its outcome in the group.

        Used instead of a done callback on Python 3.7 and for foreign coroutine
        objects; it also removes the task from the group when finished.
        """
        # This is the code path for Python 3.7 on which asyncio freaks out if a task
        # raises a BaseException.
        __traceback_hide__ = __tracebackhide__ = True  # noqa: F841
        task = cast(asyncio.Task, current_task())
        try:
            await coro
        except BaseException as exc:
            if task_status_future is None or task_status_future.done():
                self._exceptions.append(exc)
                self.cancel_scope.cancel()
            else:
                # Failed before calling task_status.started(): report to the starter
                task_status_future.set_exception(exc)
        else:
            if task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )
        finally:
            if task in self.cancel_scope._tasks:
                self.cancel_scope._tasks.remove(task)
                del _task_states[task]

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` inside this group.

        :param name: task name; defaults to the callable's dotted name when ``None``
        :param task_status_future: set when called from :meth:`start`; resolved
            via the ``task_status`` keyword argument passed to *func*
        :raises RuntimeError: if the group is not active
        :raises TypeError: if *func* does not return a coroutine
        """

        def task_done(_task: asyncio.Task) -> None:
            # This is the code path for Python 3.8+
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Report the innermost CancelledError of the chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    self._exceptions.append(exc)
                    self.cancel_scope.cancel()
                else:
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        options = {}
        name = get_callable_name(func) if name is None else str(name)
        if _native_task_names:
            options["name"] = name

        kwargs = {}
        if task_status_future is not None:
            # Until started() is called, the task belongs to the task calling start()
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not asyncio.iscoroutine(coro):
            raise TypeError(
                f"Expected an async function, but {func} appears to be synchronous"
            )

        # Coroutine-like objects without a frame attribute can't be inspected, so
        # they always go through the wrapper
        foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
        if foreign_coro or sys.version_info < (3, 8):
            coro = self._run_wrapped_task(coro, task_status_future)

        task = create_task(coro, **options)
        if not foreign_coro and sys.version_info >= (3, 8):
            task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """Spawn ``func(*args)`` as a child task without waiting for it."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Spawn ``func(*args, task_status=...)`` and wait until it reports readiness.

        :return: the value the child passed to ``task_status.started()``
        :raises RuntimeError: if the child exits without calling ``started()``
        """
        future: asyncio.Future = get_running_loop().create_future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        with CancelScope(shield=True):
            try:
                return await future
            except CancelledError:
                task.cancel()
                raise
641

642

643
#
644
# Threads
645
#
646

647
# A (return value, exception) pair; at most one of the two is meaningful
_Retval_Queue_Type = Tuple[
    Optional[T_Retval],
    Optional[BaseException],
]
648

649

650
class WorkerThread(Thread):
    """
    Thread that runs synchronous callables on behalf of an asyncio event loop.

    Work arrives on ``queue`` as ``(context, func, args, future)`` tuples; each
    outcome is handed back to the loop thread through ``call_soon_threadsafe``.
    Putting ``None`` on the queue makes the thread exit.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: Deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """Runs on the loop thread: mark this worker idle and resolve *future*."""
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # A StopIteration can't be set on a future, so wrap it
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                work_item = self.queue.get()
                if work_item is None:
                    # Shutdown command received
                    return

                ctx, fn, fn_args, fut = work_item
                if not fut.cancelled():
                    outcome = None
                    error: BaseException | None = None
                    try:
                        outcome = ctx.run(fn, *fn_args)
                    except BaseException as exc:
                        error = exc

                    # Results can only be delivered while the loop is still open
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, fut, outcome, error
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """Ask the thread to exit and drop it from the worker registries."""
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # Not currently idle
            pass
720

721

722
# RunVar-backed registries for the worker-thread pool: idle workers and all workers
_threadpool_idle_workers: RunVar[Deque[WorkerThread]]
_threadpool_idle_workers = RunVar("_threadpool_idle_workers")

_threadpool_workers: RunVar[set[WorkerThread]]
_threadpool_workers = RunVar("_threadpool_workers")
726

727

728
class BlockingPortal(abc.BlockingPortal):
    """asyncio blocking portal bound to the event loop it was created in."""

    def __new__(cls) -> BlockingPortal:
        # Skip the base class __new__ (presumably a backend dispatcher - confirm)
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """Schedule ``self._call_func`` in the portal's task group from another thread."""
        start_in_group = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_in_group, call_args, self._loop)
749

750

751
#
752
# Subprocesses
753
#
754

755

756
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Adapt an :class:`asyncio.StreamReader` to AnyIO's byte receive stream API."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to *max_bytes* bytes.

        :raises EndOfStream: when the read returns no data
        """
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        # Mark the underlying reader as having reached EOF
        self._stream.feed_eof()
769

770

771
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Adapt an :class:`asyncio.StreamWriter` to AnyIO's byte send stream API."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write *item* and wait until the writer's buffer has drained."""
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        self._stream.close()
781

782

783
@dataclass(eq=False)
class Process(abc.Process):
    """AnyIO process handle wrapping an :class:`asyncio.subprocess.Process`."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close every attached standard stream, then wait for the process to exit."""
        for stream in (self._stdin, self._stdout, self._stderr):
            if stream:
                await stream.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the process to exit and return its exit code."""
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """Exit code, or ``None`` while the process is still running."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
831

832

833
def _forcibly_shutdown_process_pool_on_exit(
10✔
834
    workers: set[Process], _task: object
835
) -> None:
836
    """
837
    Forcibly shuts down worker processes belonging to this event loop."""
838
    child_watcher: asyncio.AbstractChildWatcher | None
839
    try:
8✔
840
        child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
8✔
841
    except NotImplementedError:
8✔
842
        child_watcher = None
8✔
843

844
    # Close as much as possible (w/o async/await) to avoid warnings
845
    for process in workers:
8✔
846
        if process.returncode is None:
8✔
847
            continue
8✔
848

849
        process._stdin._stream._transport.close()  # type: ignore[union-attr]
×
850
        process._stdout._stream._transport.close()  # type: ignore[union-attr]
×
851
        process._stderr._stream._transport.close()  # type: ignore[union-attr]
×
852
        process.kill()
×
853
        if child_watcher:
×
854
            child_watcher.remove_child_handler(process.pid)
×
855

856

857
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
10✔
858
    """
859
    Shuts down worker processes belonging to this event loop.
860

861
    NOTE: this only works when the event loop was started using asyncio.run() or
862
    anyio.run().
863

864
    """
865
    process: abc.Process
866
    try:
8✔
867
        await sleep(math.inf)
8✔
868
    except asyncio.CancelledError:
8✔
869
        for process in workers:
8✔
870
            if process.returncode is None:
8✔
871
                process.kill()
8✔
872

873
        for process in workers:
8✔
874
            await process.aclose()
8✔
875

876

877
#
878
# Sockets and networking
879
#
880

881

882
class StreamProtocol(asyncio.Protocol):
    """
    Protocol that buffers incoming data and exposes read/write readiness events.

    Received chunks accumulate in ``read_queue``. ``read_event`` is set whenever
    data, EOF or connection loss arrives. ``write_event`` is set while writing may
    proceed.
    """

    read_queue: Deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        # Writing is allowed until the transport asks us to pause
        self.write_event = asyncio.Event()
        self.write_event.set()
        # A zero high-water mark makes the transport call pause_writing() whenever
        # its write buffer is non-empty
        stream_transport = cast(asyncio.Transport, transport)
        stream_transport.set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up both readers and writers so they notice the lost connection
        for event in (self.read_event, self.write_event):
            event.set()

    def data_received(self, data: bytes) -> None:
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.read_event.set()
        # A true return value keeps the transport open after EOF is received
        return True

    def pause_writing(self) -> None:
        # A fresh (unset) event makes subsequent writers wait for resume_writing()
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
916

917

918
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Datagram protocol that queues received packets as ``(data, address)`` pairs
    and exposes read/write readiness events.
    """

    read_queue: Deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bounded buffer: once full, each new packet pushes out the oldest one
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up readers and writers so they notice the closed transport
        for event in (self.read_event, self.write_event):
            event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        packet = (data, convert_ipv6_sockaddr(addr))
        self.read_queue.append(packet)
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
947

948

949
class SocketStream(abc.SocketStream):
    """
    Socket stream on top of an asyncio transport, with :class:`StreamProtocol`
    buffering the incoming data.
    """

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set by aclose(); lets receive()/send() report ClosedResourceError after a
        # local close instead of EndOfStream/BrokenResourceError
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of received data, at most ``max_bytes`` long.

        :raises ClosedResourceError: if this stream was closed with aclose()
        :raises EndOfStream: if nothing is buffered and the connection ended
            without a recorded error
        :raises Exception: the exception recorded by the protocol (a
            BrokenResourceError when set by StreamProtocol.connection_lost)
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Nothing signalled yet: resume reading only for the duration of the
            # wait, then pause again (presumably for backpressure — TODO confirm
            # where reading is initially paused)
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
            ):
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                # No buffered data: this stream was closed, the connection was
                # lost, or the peer sent EOF
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait until writing may continue.

        :raises ClosedResourceError: if this stream was closed with aclose()
        :raises BrokenResourceError: if the transport rejects the write while it is
            closing, or (via the protocol) if the connection was lost
        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                # A write rejected by a closing transport means a broken stream
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            # StreamProtocol sets a zero write-buffer limit, so this waits until the
            # transport resumes writing (its buffer drained) or the connection is lost
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        # Errors from half-closing the transport are deliberately ignored
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the stream: send EOF, close the transport, then abort it."""
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            self._transport.close()
            # Yield once so the loop can process the close before the transport is
            # forcibly aborted (presumably to give pending data a chance to flush)
            await sleep(0)
            self._transport.abort()
1031

1032

1033
class _RawSocketMixin:
10✔
1034
    _receive_future: asyncio.Future | None = None
10✔
1035
    _send_future: asyncio.Future | None = None
10✔
1036
    _closing = False
10✔
1037

1038
    def __init__(self, raw_socket: socket.socket):
10✔
1039
        self.__raw_socket = raw_socket
7✔
1040
        self._receive_guard = ResourceGuard("reading from")
7✔
1041
        self._send_guard = ResourceGuard("writing to")
7✔
1042

1043
    @property
10✔
1044
    def _raw_socket(self) -> socket.socket:
7✔
1045
        return self.__raw_socket
7✔
1046

1047
    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
10✔
1048
        def callback(f: object) -> None:
7✔
1049
            del self._receive_future
7✔
1050
            loop.remove_reader(self.__raw_socket)
7✔
1051

1052
        f = self._receive_future = asyncio.Future()
7✔
1053
        loop.add_reader(self.__raw_socket, f.set_result, None)
7✔
1054
        f.add_done_callback(callback)
7✔
1055
        return f
7✔
1056

1057
    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
10✔
1058
        def callback(f: object) -> None:
7✔
1059
            del self._send_future
7✔
1060
            loop.remove_writer(self.__raw_socket)
7✔
1061

1062
        f = self._send_future = asyncio.Future()
7✔
1063
        loop.add_writer(self.__raw_socket, f.set_result, None)
7✔
1064
        f.add_done_callback(callback)
7✔
1065
        return f
7✔
1066

1067
    async def aclose(self) -> None:
10✔
1068
        if not self._closing:
7✔
1069
            self._closing = True
7✔
1070
            if self.__raw_socket.fileno() != -1:
7✔
1071
                self.__raw_socket.close()
7✔
1072

1073
            if self._receive_future:
7✔
1074
                self._receive_future.set_result(None)
7✔
1075
            if self._send_future:
7✔
1076
                self._send_future.set_result(None)
×
1077

1078

1079
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """UNIX domain socket stream built on a non-blocking raw socket."""

    async def send_eof(self) -> None:
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes.

        :raises EndOfStream: if the peer closed its end of the connection
        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item``, waiting for writability as needed.

        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            view = memoryview(item)
            while view:
                try:
                    # Send only the not-yet-sent remainder: socket.send() may
                    # accept just part of the buffer, and resending ``item`` from
                    # the start would duplicate bytes already delivered
                    bytes_sent = self._raw_socket.send(view)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message of up to ``msglen`` bytes plus up to ``maxfds`` file
        descriptors passed via ``SCM_RIGHTS``.

        :raises ValueError: if ``msglen`` or ``maxfds`` is out of range
        :raises EndOfStream: if neither data nor ancillary data was received
        :raises RuntimeError: if unexpected ancillary data was received
        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore any trailing bytes that don't form a whole C int
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send ``message`` along with the given file descriptors via ``SCM_RIGHTS``.

        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: on any other socket error
        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    # The ignore can be removed after mypy picks up
                    # https://github.com/python/typeshed/pull/5545
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
1193

1194

1195
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections via the loop's ``sock_accept()``."""

    # Cancel scope around an in-progress accept(), so aclose() can interrupt it
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one incoming connection and wrap it in a :class:`SocketStream`.

        :raises ClosedResourceError: if the listener is closed, or gets closed
            while waiting
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # aclose() sets _closed before cancelling the scope, so report
                    # that case as a closed resource rather than a cancellation
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        # Disable Nagle's algorithm on the accepted connection
        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener (once), interrupting any accept() in progress."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Yield once, presumably so the cancelled accept() can unwind before
            # the socket is closed
            await sleep(0)

        self._raw_socket.close()
1253

1254

1255
class UNIXSocketListener(abc.SocketListener):
    """Listener accepting UNIX domain socket connections on a non-blocking socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False
        # Readiness future that an in-progress accept() is waiting on, if any.
        # aclose() resolves it, so the waiting task wakes up instead of hanging
        # forever on a socket that will never become readable again.
        self._accept_future: asyncio.Future | None = None

    async def accept(self) -> abc.SocketStream:
        """
        Accept one incoming connection as a :class:`UNIXSocketStream`.

        :raises ClosedResourceError: if the listener is (or gets) closed
        :raises BrokenResourceError: on any other socket error
        """
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(self.__raw_socket, f.set_result, None)
                    f.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    self._accept_future = f
                    try:
                        await f
                    finally:
                        self._accept_future = None
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Close the listener and wake up a pending accept(), if any."""
        self._closed = True
        self.__raw_socket.close()
        # The woken accept() retries on the closed socket, gets an OSError and
        # raises ClosedResourceError (mirrors _RawSocketMixin.aclose())
        pending = self._accept_future
        if pending is not None and not pending.done():
            pending.set_result(None)

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
1290

1291

1292
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next ``(data, address)`` packet.

        :raises ClosedResourceError: if this socket was closed with aclose()
        :raises BrokenResourceError: if the transport closed with nothing buffered
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            # Nothing buffered and the transport still open: wait for a packet
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a ``(data, address)`` packet once the protocol allows writing.

        :raises ClosedResourceError: if this socket was closed with aclose()
        :raises BrokenResourceError: if the transport is closing
        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
1338

1339

1340
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next received packet.

        :raises ClosedResourceError: if this socket was closed with aclose()
        :raises BrokenResourceError: if the transport closed with nothing buffered
        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            protocol = self._protocol
            # Nothing buffered and the transport still open: wait for a packet
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

            # Only the payload is returned; the sender address is dropped
            return packet[0]

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` to the connected peer once the protocol allows writing.

        :raises ClosedResourceError: if this socket was closed with aclose()
        :raises BrokenResourceError: if the transport is closing
        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1388

1389

1390
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket built on a non-blocking raw socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Return the next ``(data, address)`` datagram (up to 65536 bytes).

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send a ``(data, address)`` datagram, waiting for writability as needed.

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
1424

1425

1426
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket built on a non-blocking raw socket."""

    async def receive(self) -> bytes:
        """
        Return the next datagram from the connected peer (up to 65536 bytes).

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` to the connected peer, waiting for writability as needed.

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: on any other socket error
        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
1460

1461

1462
# RunVar-scoped registries mapping some key to an asyncio.Event; the keys are
# presumably the sockets/fds being waited on for readability/writability.
# NOTE(review): their users are outside this chunk — confirm the key type there.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
1464

1465

1466
#
1467
# Synchronization
1468
#
1469

1470

1471
class Event(BaseEvent):
    """AnyIO event for the asyncio backend, delegating to :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Plain allocation, bypassing BaseEvent.__new__
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        self._event.set()

    def is_set(self) -> bool:
        return self._event.is_set()

    async def wait(self) -> None:
        # asyncio.Event.wait() returns True immediately, without yielding, when the
        # event is already set; the checkpoint guarantees a scheduling point
        was_set = await self._event.wait()
        if was_set:
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        # Count of tasks blocked in wait(), read from asyncio's private waiter list
        waiters = self._event._waiters  # type: ignore[attr-defined]
        return EventStatistics(len(waiters))
1490

1491

1492
class CapacityLimiter(BaseCapacityLimiter):
    """
    asyncio implementation of AnyIO's capacity limiter.

    Token holders are tracked in ``_borrowers``. Tasks that could not get a token
    wait in ``_wait_queue`` in FIFO order. Each waiter's event is set, and the
    waiter removed from the queue, once a token has been handed to it.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Set the number of tokens, waking up waiters if capacity grew.

        :raises TypeError: if ``value`` is neither an int nor ``math.inf``
        :raises ValueError: if ``value`` is less than 1
        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        old_value = self._total_tokens
        self._total_tokens = value

        # Hand each newly created token to the next waiter in line. Woken waiters
        # are removed from the queue (as release_on_behalf_of() does); leaving
        # them queued would make acquire_on_behalf_of_nowait() keep raising
        # WouldBlock and make release() waste its wake-up on an already-woken task.
        new_tokens = value - old_value if value > old_value else 0
        while new_tokens > 0 and self._wait_queue:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            new_tokens -= 1

    @property
    def borrowed_tokens(self) -> int:
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Take a token for ``borrower`` without waiting.

        :raises RuntimeError: if ``borrower`` already holds a token
        :raises WouldBlock: if others are waiting or no token is available
        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        # Queued waiters take precedence, to keep acquisition FIFO
        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Take a token for ``borrower``, waiting in line if none is available."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # A token was already handed to us (we were popped from the
                    # queue) but we won't take it; pass it on, or the next waiter
                    # would stay blocked despite the free capacity
                    self._notify_next_waiter()
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                self.release()
                raise

    def release(self) -> None:
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Return the token held by ``borrower``.

        :raises RuntimeError: if ``borrower`` holds no token
        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's " "tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        self._notify_next_waiter()

    def _notify_next_waiter(self) -> None:
        """Wake the oldest waiter (removing it from the queue) if capacity allows."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1609

1610

1611
# RunVar holding a CapacityLimiter; judging by the name, presumably the limiter
# applied to worker threads when callers don't pass one — TODO confirm at its users
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
1612

1613

1614
#
1615
# Operating system signals
1616
#
1617

1618

1619
class _SignalReceiver:
10✔
1620
    def __init__(self, signals: tuple[Signals, ...]):
10✔
1621
        self._signals = signals
8✔
1622
        self._loop = get_running_loop()
8✔
1623
        self._signal_queue: Deque[Signals] = deque()
8✔
1624
        self._future: asyncio.Future = asyncio.Future()
8✔
1625
        self._handled_signals: set[Signals] = set()
8✔
1626

1627
    def _deliver(self, signum: Signals) -> None:
10✔
1628
        self._signal_queue.append(signum)
8✔
1629
        if not self._future.done():
8✔
1630
            self._future.set_result(None)
8✔
1631

1632
    def __enter__(self) -> _SignalReceiver:
10✔
1633
        for sig in set(self._signals):
8✔
1634
            self._loop.add_signal_handler(sig, self._deliver, sig)
8✔
1635
            self._handled_signals.add(sig)
8✔
1636

1637
        return self
8✔
1638

1639
    def __exit__(
10✔
1640
        self,
1641
        exc_type: type[BaseException] | None,
1642
        exc_val: BaseException | None,
1643
        exc_tb: TracebackType | None,
1644
    ) -> bool | None:
1645
        for sig in self._handled_signals:
8✔
1646
            self._loop.remove_signal_handler(sig)
8✔
1647
        return None
8✔
1648

1649
    def __aiter__(self) -> _SignalReceiver:
10✔
1650
        return self
8✔
1651

1652
    async def __anext__(self) -> Signals:
10✔
1653
        await AsyncIOBackend.checkpoint()
8✔
1654
        if not self._signal_queue:
8✔
1655
            self._future = asyncio.Future()
×
1656
            await self._future
×
1657

1658
        return self._signal_queue.popleft()
8✔
1659

1660

1661
#
1662
# Testing and debugging
1663
#
1664

1665

1666
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a :class:`TaskInfo` describing ``task``.

    Tasks that have an entry in ``_task_states`` report the name and parent ID
    recorded there. Any other task reports no parent, and reports its native
    asyncio name when ``_native_task_names`` is true (``None`` otherwise).
    """
    state = _task_states.get(task)
    if state is not None:
        name, parent_id = state.name, state.parent_id
    else:
        name = task.get_name() if _native_task_names else None
        parent_id = None

    return TaskInfo(id(task), parent_id, name, get_coro(task))
1676

1677

1678
async def _shutdown_default_executor(loop: asyncio.BaseEventLoop) -> None:
10✔
1679
    """Schedule the shutdown of the default executor.
1680
    BaseEventLoop.shutdown_default_executor was introduced in Python 3.9.
1681
    This function is an adapted version of the method from Python 3.11.
1682
    It's used in TestRunner.close only if python < 3.9.
1683
    """
1684

1685
    def _do_shutdown(
4✔
1686
        loop_: asyncio.BaseEventLoop, future: asyncio.futures.Future
1687
    ) -> None:
1688
        try:
4✔
1689
            loop_._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
4✔
1690
            loop_.call_soon_threadsafe(future.set_result, None)
4✔
1691
        except Exception as ex:
×
1692
            loop_.call_soon_threadsafe(future.set_exception, ex)
×
1693

1694
    if loop._default_executor is None:  # type: ignore[attr-defined]
4✔
1695
        return
4✔
1696
    future = loop.create_future()
4✔
1697
    thread = threading.Thread(
4✔
1698
        target=_do_shutdown,
1699
        args=(
1700
            loop,
1701
            future,
1702
        ),
1703
    )
1704
    thread.start()
4✔
1705
    try:
4✔
1706
        await future
4✔
1707
    finally:
1708
        thread.join()
4✔
1709

1710

1711
class TestRunner(abc.TestRunner):
    """
    Runs async tests and fixtures on a dedicated asyncio event loop.

    Every test and fixture awaitable is executed, one at a time, inside a
    single long-lived "runner" task. Task-bound state, such as cancel scopes
    entered by an async generator fixture, therefore stays in one task across
    fixture setup, the test itself and fixture teardown.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        debug: bool = False,
        use_uvloop: bool = False,
        policy: asyncio.AbstractEventLoopPolicy | None = None,
    ):
        """
        :param debug: enable asyncio debug mode on the new event loop
        :param use_uvloop: passed, together with ``policy``, to
            ``_maybe_set_event_loop_policy()`` before the loop is created
        :param policy: event loop policy passed to
            ``_maybe_set_event_loop_policy()``

        """
        self._exceptions: list[BaseException] = []
        _maybe_set_event_loop_policy(policy, use_uvloop)
        self._loop = asyncio.new_event_loop()
        self._loop.set_debug(debug)
        self._loop.set_exception_handler(self._exception_handler)
        self._runner_task: asyncio.Task | None = None
        asyncio.set_event_loop(self._loop)

    def _cancel_all_tasks(self) -> None:
        """
        Cancel every task still pending on the loop and wait for all of them.

        :raises BaseException: the error of a task that ended with an exception
            instead of being cancelled

        """
        to_cancel = all_tasks(self._loop)
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()

        self._loop.run_until_complete(
            asyncio.gather(*to_cancel, return_exceptions=True)
        )

        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                raise cast(BaseException, task.exception())

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """
        Loop exception handler.

        ``Exception`` instances are collected so they can be re-raised in the
        test later. Anything else goes to the loop's default handler.
        """
        if isinstance(context.get("exception"), Exception):
            self._exceptions.append(context["exception"])
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        """
        Re-raise exceptions collected from asynchronous callbacks, if any.

        A single exception is raised as-is; several are wrapped in a
        ``BaseExceptionGroup``. The collected list is reset either way.
        """
        if self._exceptions:
            exceptions, self._exceptions = self._exceptions, []
            if len(exceptions) == 1:
                raise exceptions[0]

            raise BaseExceptionGroup(
                "Multiple exceptions occurred in asynchronous callbacks", exceptions
            )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """
        Body of the runner task.

        Awaits each received awaitable in turn and reports its outcome through
        the paired future. Futures the caller has already cancelled are left
        untouched. Runs until iteration over ``receive_stream`` ends, which
        :meth:`close` triggers by closing the send side.
        """
        with receive_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except BaseException as exc:
                    if not future.cancelled():
                        future.set_exception(exc)
                else:
                    if not future.cancelled():
                        future.set_result(retval)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """
        Evaluate ``func(*args, **kwargs)`` and await the result in the runner task.

        The runner task and its memory object stream are created on first use.
        """
        if not self._runner_task:
            self._send_stream, receive_stream = create_memory_object_stream(1)
            self._runner_task = self._loop.create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self._loop.create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def close(self) -> None:
        """
        Tear down the runner and its event loop.

        Steps, in order: stop the runner task (by closing its stream), cancel
        any leftover tasks, finalize async generators, and shut down the
        default executor. Finally the current event loop is unset and the loop
        is closed, even if an earlier step failed.
        """
        try:
            if self._runner_task is not None:
                self._runner_task = None
                self._loop.run_until_complete(self._send_stream.aclose())
                del self._send_stream

            self._cancel_all_tasks()
            self._loop.run_until_complete(self._loop.shutdown_asyncgens())
            if hasattr(self._loop, "shutdown_default_executor"):
                # asyncio in Python >= 3.9 or uvloop >= 0.15.0
                self._loop.run_until_complete(self._loop.shutdown_default_executor())
            elif isinstance(self._loop, asyncio.BaseEventLoop) and hasattr(
                self._loop, "_default_executor"
            ):
                # asyncio in Python < 3.9
                self._loop.run_until_complete(_shutdown_default_executor(self._loop))
        finally:
            asyncio.set_event_loop(None)
            self._loop.close()

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Drive an async generator fixture.

        Yields the fixture's first value. When resumed after the test, it
        resumes the generator in the runner task to run its teardown.

        :raises RuntimeError: if the fixture yields more than once

        """
        asyncgen = fixture_func(**kwargs)
        fixturevalue: T_Retval = self._loop.run_until_complete(
            self._call_in_runner_task(asyncgen.asend, None)
        )
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            self._loop.run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            # Close the generator in the same task that ran the rest of it, so
            # task-bound state it entered (e.g. cancel scopes) exits cleanly
            self._loop.run_until_complete(
                self._call_in_runner_task(asyncgen.aclose)
            )
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        retval = self._loop.run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """
        Run a coroutine test function in the runner task.

        An ``Exception`` raised by the test is re-raised together with any
        exceptions collected from asynchronous callbacks.
        """
        try:
            self._loop.run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
1862

1863

1864
class AsyncIOBackend(AsyncBackend):
    """AnyIO backend implemented on top of the standard library's asyncio."""

    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func(*args, **kwargs)`` to completion in a new event loop.

        Recognized ``options``:

        * ``debug`` (default ``False``)
        * ``policy`` (default ``None``)
        * ``use_uvloop`` (default ``False``)

        ``policy`` and ``use_uvloop`` are passed to
        ``_maybe_set_event_loop_policy()`` before the loop starts.
        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            # Register the root task in _task_states, which the task-info and
            # cancellation helpers read
            task = cast(asyncio.Task, current_task())
            task_state = TaskState(None, get_callable_name(func), None)
            _task_states[task] = task_state
            if _native_task_names:
                task.set_name(task_state.name)

            try:
                return await func(*args, **kwargs)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        policy = options.get("policy", None)
        use_uvloop = options.get("use_uvloop", False)
        _maybe_set_event_loop_policy(policy, use_uvloop)
        return native_run(wrapper(), debug=debug)

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop; it serves as the cross-thread token."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's clock time."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class asyncio uses for cancellation."""
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield to the event loop only if cancellation is pending.

        Walks outward from the current task's innermost cancel scope. It stops
        at the first shielded scope, and yields when a cancelled scope is
        found. Tasks without AnyIO state return immediately.
        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield control to the event loop once, shielded from cancellation."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Sleep for ``delay`` seconds."""
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Create a new cancel scope with the given deadline and shielding."""
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline affecting the current task.

        Walks outward from the innermost cancel scope, stopping at the first
        shielded scope. Returns ``-inf`` if a scope on that path has already
        been cancelled, and ``inf`` if the task has no AnyIO state.
        """
        try:
            cancel_scope = _task_states[
                current_task()  # type: ignore[index]
            ].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group."""
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event."""
        return Event()

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Create a capacity limiter with ``total_tokens`` tokens."""
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        cancellable: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in a worker thread and wait for its result.

        Idle workers are reused when available. Otherwise a new worker is
        started, and it is stopped when the root task finishes. Whenever an
        idle worker is reused, the idle workers that have waited at least
        ``WorkerThread.MAX_IDLE_TIME`` seconds are stopped.

        ``func`` runs in a copy of the caller's context in which sniffio's
        current async library is cleared.

        :param cancellable: if ``False``, the wait is shielded from cancellation
        :param limiter: capacity limiter to use (defaults to
            :meth:`current_default_thread_limiter`)

        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with (limiter or cls.current_default_thread_limiter()):
            with CancelScope(shield=not cancellable):
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                worker.queue.put_nowait((context, func, args, future))
                return await future

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` on the event loop given by ``token``.

        Blocks the calling thread until the result is available.
        """
        loop = cast(AbstractEventLoop, token)
        f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
            func(*args), loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` in the event loop thread given by ``token``.

        Blocks the calling thread until the call completes.
        """

        @wraps(func)
        def wrapper() -> None:
            try:
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Let non-Exception errors (e.g. KeyboardInterrupt) also reach
                # the event loop
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal."""
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """
        Start a subprocess and wrap its pipes in AnyIO streams.

        When ``shell`` is true, ``command`` is passed to the shell as a single
        string; otherwise it is treated as an argument sequence.
        """
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the process pool ``workers`` to be shut down on exit.

        Starts a shutdown task and adds a done callback to the root task that
        shuts the pool down forcibly.
        """
        kwargs = (
            {"name": "AnyIO process pool shutdown task"} if _native_task_names else {}
        )
        create_task(_shutdown_process_pool_on_exit(workers), **kwargs)
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """Open a TCP connection to ``host``:``port``."""
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        # Reading is resumed on demand by the stream
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
        """
        Connect to the UNIX socket at ``path``.

        The raw socket is closed if connecting fails or is interrupted.
        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        while True:
            try:
                raw_socket.connect(path)
            except BlockingIOError:
                # Wait until the socket is writable, then retry connect()
                f: asyncio.Future = asyncio.Future()
                loop.add_writer(raw_socket, f.set_result, None)
                f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                await f
            except BaseException:
                raw_socket.close()
                raise
            else:
                return UNIXSocketStream(raw_socket)

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a bound TCP socket in a listener."""
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a bound UNIX socket in a listener."""
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a UDP socket.

        The socket is connected if ``remote_address`` is given.

        :raises BaseException: the exception the protocol recorded during setup,
            after closing the transport

        """
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap ``raw_socket`` as a UNIX datagram socket.

        The socket is connected to ``remote_path`` if one is given; it is
        closed if connecting fails.
        """
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                except BaseException:
                    raw_socket.close()
                    raise
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host``/``port`` using the running loop's ``getaddrinfo()``."""
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Look up ``sockaddr`` using the running loop's ``getnameinfo()``."""
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` becomes readable.

        :raises BusyResourceError: if another task is already waiting for
            ``sock`` to become readable
        :raises ClosedResourceError: if this socket's wait entry was removed by
            someone else while waiting (presumably because the socket was
            closed)

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        try:
            loop.add_reader(sock, event.set)
        except BaseException:
            # Don't leave a stale entry behind, or every later wait on this
            # socket would fail with BusyResourceError
            read_events.pop(sock, None)
            raise

        try:
            await event.wait()
        finally:
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` becomes writable.

        :raises BusyResourceError: if another task is already waiting for
            ``sock`` to become writable
        :raises ClosedResourceError: if this socket's wait entry was removed by
            someone else while waiting (presumably because the socket was
            closed)

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        try:
            # Register the socket object itself (as wait_socket_readable does),
            # so the matching remove_writer(sock) can find it by object
            loop.add_writer(sock, event.set)
        except BaseException:
            # Don't leave a stale entry behind, or every later wait on this
            # socket would fail with BusyResourceError
            write_events.pop(sock, None)
            raise

        try:
            await event.wait()
        finally:
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """Return the default worker thread limiter, creating it (40 tokens) lazily."""
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """Return a context manager that yields the given signals as they arrive."""
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return information about the currently running task."""
        return _create_task_info(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        """Return information about every task that has not finished yet."""
        return [_create_task_info(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every other task is blocked waiting on a pending future.

        Polls every 0.1 seconds while any other task is runnable.
        """
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """Create a :class:`TestRunner` configured with ``options``."""
        return TestRunner(**options)
2329

2330

2331
# Module-level name exposing this backend's implementation class; presumably
# looked up by AnyIO's backend loader (confirm against _core/_eventloop.py)
backend_class = AsyncIOBackend
10✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc