• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / anyio / 4854167299

pending completion
4854167299

Pull #558

github

GitHub
Merge 4123b805c into 69be26e81
Pull Request #558: Turned TaskStatus into a protocol

13 of 15 new or added lines in 2 files covered. (86.67%)

1 existing line in 1 file now uncovered.

3927 of 4350 relevant lines covered (90.28%)

8.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.92
/src/anyio/_backends/_asyncio.py
1
from __future__ import annotations
10✔
2

3
import array
10✔
4
import asyncio
10✔
5
import concurrent.futures
10✔
6
import math
10✔
7
import socket
10✔
8
import sys
10✔
9
import threading
10✔
10
from asyncio import (
10✔
11
    AbstractEventLoop,
12
    CancelledError,
13
    all_tasks,
14
    create_task,
15
    current_task,
16
    get_running_loop,
17
    sleep,
18
)
19
from asyncio import run as native_run
10✔
20
from asyncio.base_events import _run_until_complete_cb  # type: ignore[attr-defined]
10✔
21
from collections import OrderedDict, deque
10✔
22
from collections.abc import AsyncIterator, Iterable
10✔
23
from concurrent.futures import Future
10✔
24
from contextvars import Context, copy_context
10✔
25
from dataclasses import dataclass
10✔
26
from functools import partial, wraps
10✔
27
from inspect import CORO_RUNNING, CORO_SUSPENDED, getcoroutinestate
10✔
28
from io import IOBase
10✔
29
from os import PathLike
10✔
30
from queue import Queue
10✔
31
from signal import Signals
10✔
32
from socket import AddressFamily, SocketKind
10✔
33
from threading import Thread
10✔
34
from types import TracebackType
10✔
35
from typing import (
10✔
36
    IO,
37
    Any,
38
    AsyncGenerator,
39
    Awaitable,
40
    Callable,
41
    Collection,
42
    ContextManager,
43
    Coroutine,
44
    Deque,
45
    Generator,
46
    Iterator,
47
    Mapping,
48
    Optional,
49
    Sequence,
50
    Tuple,
51
    TypeVar,
52
    cast,
53
)
54
from weakref import WeakKeyDictionary
10✔
55

56
import sniffio
10✔
57

58
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
10✔
59
from .._core._eventloop import claim_worker_thread
10✔
60
from .._core._exceptions import (
10✔
61
    BrokenResourceError,
62
    BusyResourceError,
63
    ClosedResourceError,
64
    EndOfStream,
65
    WouldBlock,
66
)
67
from .._core._sockets import convert_ipv6_sockaddr
10✔
68
from .._core._streams import create_memory_object_stream
10✔
69
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
10✔
70
from .._core._synchronization import Event as BaseEvent
10✔
71
from .._core._synchronization import ResourceGuard
10✔
72
from .._core._tasks import CancelScope as BaseCancelScope
10✔
73
from ..abc import (
10✔
74
    AsyncBackend,
75
    IPSockAddrType,
76
    SocketListener,
77
    UDPPacketType,
78
    UNIXDatagramPacketType,
79
)
80
from ..abc._tasks import T_contra
10✔
81
from ..lowlevel import RunVar
10✔
82
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10✔
83

84
if sys.version_info < (3, 11):
10✔
85
    from exceptiongroup import BaseExceptionGroup, ExceptionGroup
7✔
86

87
# Two variants of the same accessor: the public Task.get_coro() is used on
# Python 3.8 and later, the private ``_coro`` attribute on older versions.
if sys.version_info < (3, 8):

    def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
        """Return the coroutine object wrapped by *task* (pre-3.8 variant)."""
        return task._coro

else:

    def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
        """Return the coroutine object wrapped by *task*."""
        return task.get_coro()
96

97

98
T_Retval = TypeVar("T_Retval")
10✔
99

100
# Check whether there is native support for task names in asyncio (3.8+)
101
_native_task_names = hasattr(asyncio.Task, "get_name")
10✔
102

103
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
10✔
104

105

106
def find_root_task() -> asyncio.Task:
    """
    Locate the task regarded as the root task of the running event loop.

    The lookup proceeds in this order:

    #. the task cached in ``_root_task``, if it has not finished yet
    #. an unfinished task carrying the ``run_until_complete()`` done callback (or
       a callback whose ``__module__`` is ``uvloop.loop``); the match is cached
    #. the host task of the outermost cancel scope known for the current task
    #. the current task itself
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback, _context in candidate._callbacks:
            if (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            ):
                _root_task.set(candidate)
                return candidate

    # Look up the topmost task in the AnyIO task tree, if possible
    task = cast(asyncio.Task, current_task())
    state = _task_states.get(task)
    if state:
        # Climb from the task's innermost cancel scope to the outermost one
        scope = state.cancel_scope
        while scope and scope._parent_scope is not None:
            scope = scope._parent_scope

        if scope is not None:
            return cast(asyncio.Task, scope._host_task)

    return task
135

136

137
def get_callable_name(func: Callable) -> str:
    """
    Build a dotted name for *func* from its ``__module__`` and ``__qualname__``.

    A missing or empty attribute is left out, so the result may consist of only one
    of the two parts, or be an empty string when neither is available.
    """
    parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(filter(None, parts))
141

142

143
#
144
# Event loop
145
#
146

147
_run_vars = (
10✔
148
    WeakKeyDictionary()
149
)  # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any]
150

151

152
def _task_started(task: asyncio.Task) -> bool:
    """
    Return ``True`` if the task has been started and has not finished.

    :raises Exception: if the state of the task's coroutine cannot be inspected

    """
    coro = cast(Coroutine[Any, Any, Any], get_coro(task))
    try:
        coro_state = getcoroutinestate(coro)
    except AttributeError:
        # task coro is async_generator_asend https://bugs.python.org/issue37771
        raise Exception(f"Cannot determine if task {task} has started or not")

    return coro_state in (CORO_RUNNING, CORO_SUSPENDED)
160

161

162
def _maybe_set_event_loop_policy(
10✔
163
    policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool
164
) -> None:
165
    # On CPython, use uvloop when possible if no other policy has been given and if not
166
    # explicitly disabled
167
    if policy is None and use_uvloop and sys.implementation.name == "cpython":
10✔
168
        try:
×
169
            import uvloop
×
170
        except ImportError:
×
171
            pass
×
172
        else:
173
            # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier)
174
            if not hasattr(
×
175
                asyncio.AbstractEventLoop, "shutdown_default_executor"
176
            ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"):
177
                policy = uvloop.EventLoopPolicy()
×
178

179
    if policy is not None:
10✔
180
        asyncio.set_event_loop_policy(policy)
10✔
181

182

183
#
184
# Timeouts and cancellation
185
#
186

187

188
class CancelScope(BaseCancelScope):
    """
    Cancel scope implementation for the asyncio backend.

    A scope remembers the task that entered it (``_host_task``), the tasks running
    under it (``_tasks``) and the scope that was current in the host task when it
    was entered (``_parent_scope``). Cancellation is delivered by calling
    ``Task.cancel()`` on eligible tasks and rescheduling the delivery on the event
    loop for as long as eligible tasks remain.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        # Settings
        self._deadline = deadline
        self._shield = shield
        # Position in the task/scope tree
        self._parent_scope: CancelScope | None = None
        self._host_task: asyncio.Task | None = None
        self._tasks: set[asyncio.Task] = set()
        # Runtime state
        self._active = False
        self._cancel_called = False
        self._timeout_expired = False
        self._timeout_handle: asyncio.TimerHandle | None = None
        self._cancel_handle: asyncio.Handle | None = None

    def __enter__(self) -> CancelScope:
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        host_task = cast(asyncio.Task, current_task())
        self._host_task = host_task
        self._tasks.add(host_task)

        task_state = _task_states.get(host_task)
        if task_state is None:
            # First scope entered in this task: create its state record
            task_name = host_task.get_name() if _native_task_names else None
            _task_states[host_task] = TaskState(None, task_name, self)
        else:
            # Nest this scope inside the task's current scope
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        # Make the enclosing scope the host task's current scope again
        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent
        # scope if this one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        if exc_val is None:
            return None

        if isinstance(exc_val, ExceptionGroup):
            exceptions = exc_val.exceptions
        else:
            exceptions = [exc_val]

        # Only exceptions consisting purely of cancellations may be swallowed
        if not all(isinstance(exc, CancelledError) for exc in exceptions):
            return None

        if self._timeout_expired:
            return True

        if not self._cancel_called:
            # Task was cancelled natively
            return None

        if not self._parent_cancelled():
            # This scope was directly cancelled
            return True

        return None

    def _timeout(self) -> None:
        """Cancel the scope if the deadline has passed, else schedule a re-check."""
        if self._deadline == math.inf:
            return

        loop = get_running_loop()
        if loop.time() >= self._deadline:
            self._timeout_expired = True
            self.cancel()
        else:
            self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.
        """
        retry_needed = False
        running_task = current_task()
        for task in self._tasks:
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started and is not in a
            # cancel scope shielded from this one
            scope = _task_states[task].cancel_scope
            while scope is not self and scope is not None and not scope._shield:
                scope = scope._parent_scope

            if scope is not self:
                # Hit a shielded scope (or ran out of scopes) before reaching this one
                continue

            retry_needed = True
            if task is not running_task and (
                task is self._host_task or _task_started(task)
            ):
                task.cancel()

        # Schedule another callback if there are still tasks left
        if retry_needed:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        target: CancelScope | None = None
        scope = self._parent_scope
        while scope is not None:
            # Only consider cancelled scopes with no delivery callback pending
            if scope._cancel_called and scope._cancel_handle is None:
                target = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if target is not None:
            target._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """
        Check whether any parent has been cancelled.

        The search stops at the first shielded parent scope.
        """
        scope = self._parent_scope
        while scope is not None and not scope._shield:
            if scope._cancel_called:
                return True

            scope = scope._parent_scope

        return False

    def cancel(self) -> None:
        if self._cancel_called:
            return

        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._cancel_called = True
        # Delivery is only possible once the scope has been entered
        if self._host_task is not None:
            self._deliver_cancellation()

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        # Drop the timer set for the previous deadline
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield == value:
            return

        self._shield = value
        # Lifting the shield exposes this scope to cancelled parent scopes
        if not value:
            self._deliver_cancellation_to_parent()
396

397

398
#
399
# Task states
400
#
401

402

403
class TaskState:
    """
    Encapsulates auxiliary task information that cannot be added to the Task instance
    itself because there are no guarantees about its implementation.

    The record holds the ``id()`` of the parent task (if any), the task's name (if
    any) and the task's current cancel scope (if any).
    """

    __slots__ = ("parent_id", "name", "cancel_scope")

    def __init__(
        self, parent_id: int | None, name: str | None, cancel_scope: CancelScope | None
    ):
        self.cancel_scope = cancel_scope
        self.name = name
        self.parent_id = parent_id
417

418

419
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
10✔
420

421

422
#
423
# Task groups
424
#
425

426

427
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object handed to functions launched through ``TaskGroup.start()``.

    Calling :meth:`started` resolves the given future with the start value and
    records ``parent_id`` as the parent of the calling task.
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._parent_id = parent_id
        self._future = future

    def started(self, value: T_contra | None = None) -> None:
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            message = "called 'started' twice on the same task status"
            raise RuntimeError(message) from None

        # Re-parent the calling task
        calling_task = cast(asyncio.Task, current_task())
        _task_states[calling_task].parent_id = self._parent_id
442

443

444
def collapse_exception_group(excgroup: BaseExceptionGroup) -> BaseException:
10✔
445
    exceptions = list(excgroup.exceptions)
10✔
446
    modified = False
10✔
447
    for i, exc in enumerate(exceptions):
10✔
448
        if isinstance(exc, BaseExceptionGroup):
10✔
449
            new_exc = collapse_exception_group(exc)
9✔
450
            if new_exc is not exc:
9✔
451
                modified = True
9✔
452
                exceptions[i] = new_exc
9✔
453

454
    if len(exceptions) == 1:
10✔
455
        return exceptions[0]
10✔
456
    elif modified:
9✔
457
        return excgroup.derive(exceptions)
9✔
458
    else:
459
        return excgroup
9✔
460

461

462
def walk_exception_group(excgroup: BaseExceptionGroup) -> Iterator[BaseException]:
10✔
463
    for exc in excgroup.exceptions:
9✔
464
        if isinstance(exc, BaseExceptionGroup):
9✔
465
            yield from walk_exception_group(exc)
×
466
        else:
467
            yield exc
9✔
468

469

470
def is_anyio_cancelled_exc(exc: BaseException) -> bool:
    """Return ``True`` if ``exc`` is a ``CancelledError`` that carries no arguments."""
    if not isinstance(exc, CancelledError):
        return False

    return not exc.args
472

473

474
class TaskGroup(abc.TaskGroup):
    """
    Task group implementation for the asyncio backend.

    All spawned tasks run under the group's own cancel scope. Exceptions raised by
    the tasks (and by the body of the ``async with`` block) are collected in
    ``_exceptions`` and processed when the group is exited.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The block itself failed: cancel the child tasks and record the error
            self.cancel_scope.cancel()
            self._exceptions.append(exc_val)

        # Wait until every task in the group's cancel scope has finished
        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except asyncio.CancelledError:
                self.cancel_scope.cancel()

        self._active = False
        if not self._exceptions:
            return ignore_exception

        exc: BaseException | None
        group = BaseExceptionGroup("multiple tasks failed", self._exceptions)
        if not self.cancel_scope._parent_cancelled():
            # If any exceptions other than AnyIO cancellation exceptions have been
            # received, raise those
            _, exc = group.split(is_anyio_cancelled_exc)
        elif all(is_anyio_cancelled_exc(e) for e in walk_exception_group(group)):
            # All tasks were cancelled by AnyIO
            exc = CancelledError()
        else:
            exc = group

        if isinstance(exc, BaseExceptionGroup):
            exc = collapse_exception_group(exc)

        # The exception passed to __aexit__ is not re-raised from here
        if exc is not None and exc is not exc_val:
            raise exc

        return ignore_exception

    async def _run_wrapped_task(
        self, coro: Coroutine, task_status_future: asyncio.Future | None
    ) -> None:
        # This is the code path for Python 3.7 on which asyncio freaks out if a task
        # raises a BaseException.
        __traceback_hide__ = __tracebackhide__ = True  # noqa: F841
        this_task = cast(asyncio.Task, current_task())
        try:
            await coro
        except BaseException as exc:
            if task_status_future is not None and not task_status_future.done():
                # start() is still waiting for this task: report the failure there
                task_status_future.set_exception(exc)
            else:
                self._exceptions.append(exc)
                self.cancel_scope.cancel()
        else:
            if task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )
        finally:
            if this_task in self.cancel_scope._tasks:
                self.cancel_scope._tasks.remove(this_task)
                del _task_states[this_task]

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a task running ``func(*args)`` under this group's cancel scope.

        :param func: the async function to call
        :param args: positional arguments for ``func``
        :param name: name of the task, or ``None`` to derive one from ``func``
        :param task_status_future: if given, ``func`` also receives a
            ``task_status`` keyword argument that resolves this future
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` does not return a coroutine object

        """

        def task_done(_task: asyncio.Task) -> None:
            # This is the code path for Python 3.8+
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Use the innermost cancellation error of the context chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            waiting_for_start = (
                task_status_future is not None and not task_status_future.done()
            )
            if exc is not None:
                if waiting_for_start:
                    task_status_future.set_exception(exc)
                else:
                    self._exceptions.append(exc)
                    self.cancel_scope.cancel()
            elif waiting_for_start:
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        name = get_callable_name(func) if name is None else str(name)
        options = {"name": name} if _native_task_names else {}

        kwargs = {}
        if task_status_future:
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not asyncio.iscoroutine(coro):
            raise TypeError(
                f"Expected an async function, but {func} appears to be synchronous"
            )

        # Coroutine-like objects lacking a frame attribute, and all coroutines on
        # Python 3.7, are run through the wrapper instead of the done callback
        foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
        use_wrapper = foreign_coro or sys.version_info < (3, 8)
        if use_wrapper:
            coro = self._run_wrapped_task(coro, task_status_future)

        task = create_task(coro, **options)
        if not use_wrapper:
            task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        start_future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, start_future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        with CancelScope(shield=True):
            try:
                return await start_future
            except CancelledError:
                task.cancel()
                raise
642

643

644
#
645
# Threads
646
#
647

648
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
10✔
649

650

651
class WorkerThread(Thread):
    """
    Thread that runs queued callables and reports their outcome to the event loop.

    Work items are ``(context, func, args, future)`` tuples placed on ``queue``;
    putting ``None`` on the queue makes the thread exit. Results are handed to the
    event loop with ``call_soon_threadsafe()``.
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: Deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.loop = root_task._loop
        self.workers = workers
        self.idle_workers = idle_workers
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """Mark this worker idle and settle ``future`` unless it was cancelled."""
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        if isinstance(exc, StopIteration):
            # StopIteration is not passed to the future as such; wrap it
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future = item
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc

                    # The outcome is dropped if the event loop has been closed
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """Tell the thread to exit and remove it from the worker collections."""
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # The worker was not idle
            pass
721

722

723
_threadpool_idle_workers: RunVar[Deque[WorkerThread]] = RunVar(
10✔
724
    "_threadpool_idle_workers"
725
)
726
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
10✔
727

728

729
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop that was running on creation."""

    def __new__(cls) -> BlockingPortal:
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        # Have the portal's event loop start the call as a task in the task group
        start_task = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_task, call_args, self._loop)
750

751

752
#
753
# Subprocesses
754
#
755

756

757
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Byte receive stream backed by an :class:`asyncio.StreamReader`."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            # An empty read means the stream has ended
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        self._stream.feed_eof()
770

771

772
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Byte send stream backed by an :class:`asyncio.StreamWriter`."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        writer = self._stream
        writer.write(item)
        # Wait until the writer is ready to accept more data
        await writer.drain()

    async def aclose(self) -> None:
        self._stream.close()
782

783

784
@dataclass(eq=False)
class Process(abc.Process):
    """Process implementation wrapping an :class:`asyncio.subprocess.Process`."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        # Close the standard streams that exist, then wait for the process to exit
        for stream in (self._stdin, self._stdout, self._stderr):
            if stream:
                await stream.aclose()

        await self.wait()

    async def wait(self) -> int:
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
832

833

834
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shut down worker processes belonging to this event loop.

    :param workers: the worker processes to clean up
    :param _task: unused; accepted so the function fits a two-argument callback
        signature

    """
    watcher: asyncio.AbstractChildWatcher | None
    try:
        watcher = asyncio.get_event_loop_policy().get_child_watcher()
    except NotImplementedError:
        # No child watcher on this event loop policy
        watcher = None

    # Close as much as possible (w/o async/await) to avoid warnings
    # NOTE(review): only workers that already have a return code get past this
    # check; workers that are still running (returncode is None) are skipped.
    # Confirm that this is the intended condition.
    for worker in workers:
        if worker.returncode is not None:
            for pipe in (worker._stdin, worker._stdout, worker._stderr):
                pipe._stream._transport.close()  # type: ignore[union-attr]

            worker.kill()
            if watcher:
                watcher.remove_child_handler(worker.pid)
856

857

858
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shut down worker processes belonging to this event loop.

    Sleeps indefinitely; when cancelled, kills every worker that has no return
    code yet and then closes all workers. The cancellation itself is not
    re-raised.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    worker: abc.Process
    try:
        await sleep(math.inf)
    except asyncio.CancelledError:
        # First pass: kill everything that is still running
        for worker in workers:
            if worker.returncode is None:
                worker.kill()

        # Second pass: close streams and reap every worker
        for worker in workers:
            await worker.aclose()
876

877

878
#
879
# Sockets and networking
880
#
881

882

883
class StreamProtocol(asyncio.Protocol):
    """asyncio protocol that queues received chunks and exposes read/write events."""

    read_queue: Deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the queue and events, and set the write buffer limit to 0."""
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        # Writing starts out as allowed
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        """Record any error as a BrokenResourceError and wake up all waiters."""
        if exc:
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        """Queue the received chunk and flag that data is available."""
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        """Wake up the reader; returns ``True``."""
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        """Replace the write event with a fresh, unset one."""
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        """Set the current write event."""
        self.write_event.set()
917

918

919
class DatagramProtocol(asyncio.DatagramProtocol):
    """asyncio datagram protocol that queues packets and exposes read/write events."""

    read_queue: Deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Create the bounded packet queue and the events (writing allowed)."""
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        """Wake up both readers and writers."""
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        """Queue the packet with its converted address and flag availability."""
        packet = (data, convert_ipv6_sockaddr(addr))
        self.read_queue.append(packet)
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        """Remember the reported error."""
        self.exception = exc

    def pause_writing(self) -> None:
        """Clear the write event."""
        self.write_event.clear()

    def resume_writing(self) -> None:
        """Set the write event."""
        self.write_event.set()
948

949

950
class SocketStream(abc.SocketStream):
    """Socket stream built on an asyncio transport/protocol pair."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object stored in the transport's extra info."""
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of received data, at most ``max_bytes`` long.

        :raises ClosedResourceError: if the queue is empty and this stream was closed
        :raises EndOfStream: if the queue is empty and no error was recorded
        :raises Exception: the protocol's recorded exception, if the queue is empty

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol
            transport = self._transport

            # Nothing flagged as available yet: let the transport read until
            # the protocol signals data, EOF or connection loss
            if not (protocol.read_event.is_set() or transport.is_closing()):
                transport.resume_reading()
                await protocol.read_event.wait()
                transport.pause_reading()

            try:
                chunk = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                elif protocol.exception:
                    raise protocol.exception
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                leftover = chunk[max_bytes:]
                chunk = chunk[:max_bytes]
                protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not protocol.read_queue:
                protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait until writing is allowed again.

        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: if the write fails while the transport is closing
        :raises Exception: the protocol's recorded exception, if there is one

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError

            pending_error = self._protocol.exception
            if pending_error is not None:
                raise pending_error

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if not self._transport.is_closing():
                    raise

                raise BrokenResourceError from exc

            # Look up the event only now: the protocol may have replaced it
            # while the write was being processed
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Half-close the write side, ignoring ``OSError``."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        transport = self._transport
        if transport.is_closing():
            return

        self._closed = True
        try:
            transport.write_eof()
        except OSError:
            pass

        transport.close()
        await sleep(0)
        transport.abort()
1032

1033

1034
class _RawSocketMixin:
    """Shared plumbing for stream/datagram classes that drive a raw socket."""

    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        """The raw socket this object operates on."""
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that is resolved when the loop reports readability."""

        def on_done(_future: object) -> None:
            # Drop the instance attribute (the class-level None shows through
            # again) and stop watching the socket
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        future: asyncio.Future = asyncio.Future()
        self._receive_future = future
        loop.add_reader(self.__raw_socket, future.set_result, None)
        future.add_done_callback(on_done)
        return future

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that is resolved when the loop reports writability."""

        def on_done(_future: object) -> None:
            # Drop the instance attribute (the class-level None shows through
            # again) and stop watching the socket
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        future: asyncio.Future = asyncio.Future()
        self._send_future = future
        loop.add_writer(self.__raw_socket, future.set_result, None)
        future.add_done_callback(on_done)
        return future

    async def aclose(self) -> None:
        """Close the raw socket once and resolve any pending wait futures."""
        if self._closing:
            return

        self._closing = True
        raw_socket = self.__raw_socket
        if raw_socket.fileno() != -1:
            raw_socket.close()

        # Wake up pending receive/send operations so they can notice the closure
        if self._receive_future:
            self._receive_future.set_result(None)
        if self._send_future:
            self._send_future.set_result(None)
1078

1079

1080
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """UNIX socket stream operating directly on a raw socket."""

    async def send_eof(self) -> None:
        """Shut down the write side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes from the socket.

        :raises EndOfStream: if the socket returns no data
        :raises ClosedResourceError: if the socket fails while this stream is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item`` through the socket.

        :raises ClosedResourceError: if the socket fails while this stream is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            view = memoryview(item)
            while view:
                try:
                    # Send only what is still unsent; passing ``item`` here
                    # would retransmit the start of the payload after every
                    # partial send
                    bytes_sent = self._raw_socket.send(view)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message together with file descriptors from ancillary data.

        :param msglen: maximum length of the message (non-negative integer)
        :param maxfds: maximum number of file descriptors (positive integer)
        :return: a tuple of (message, list of file descriptors)
        :raises ValueError: if ``msglen`` or ``maxfds`` is invalid
        :raises EndOfStream: if neither message nor ancillary data was received
        :raises RuntimeError: if unexpected ancillary data was received
        :raises ClosedResourceError: if the socket fails while this stream is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore trailing bytes that do not make up a whole descriptor
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message along with file descriptors as ancillary data.

        :param message: a non-empty message
        :param fds: a non-empty collection of integer descriptors and/or file
            objects (items of any other type are ignored)
        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: if the socket fails while this stream is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
1194

1195

1196
class TCPSocketListener(abc.SocketListener):
    """TCP listener that accepts connections through the event loop."""

    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The listening socket."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`SocketStream`.

        :raises ClosedResourceError: if the listener is closed, or gets closed
            while waiting for a connection

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    accepted = await self._loop.sock_accept(self._raw_socket)
                    client_sock, _addr = accepted
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        # Enable TCP_NODELAY on the accepted connection
        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling an in-progress ``accept()`` first."""
        if self._closed:
            return

        self._closed = True
        scope = self._accept_scope
        if scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            scope.cancel()
            await sleep(0)

        self._raw_socket.close()
1254

1255

1256
class UNIXSocketListener(abc.SocketListener):
    """UNIX socket listener operating directly on a raw socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The listening socket."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one connection and wrap it in a :class:`UNIXSocketStream`.

        :raises ClosedResourceError: if the socket fails after this listener was closed
        :raises BrokenResourceError: if the socket fails for any other reason

        """

        def stop_watching(_future: object) -> None:
            self._loop.remove_reader(self.__raw_socket)

        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    # No pending connection: wait until the loop reports the
                    # listening socket as readable, then try again
                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(
                        self.__raw_socket, readable.set_result, None
                    )
                    readable.add_done_callback(stop_watching)
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener as closed and close the raw socket."""
        self._closed = True
        self.__raw_socket.close()
1291

1292

1293
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket built on an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object stored in the transport's extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next queued ``(data, address)`` packet.

        :raises ClosedResourceError: if the queue is empty and this socket was closed
        :raises BrokenResourceError: if the queue is empty for any other reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol

            # If the buffer is empty, ask for more data
            if not protocol.read_queue and not self._transport.is_closing():
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a packet once writing is allowed; ``item`` is unpacked into ``sendto()``.

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: if the transport is closing

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
1339

1340

1341
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket built on an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object stored in the transport's extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next queued packet.

        :raises ClosedResourceError: if the queue is empty and this socket was closed
        :raises BrokenResourceError: if the queue is empty for any other reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol

            # If the buffer is empty, ask for more data
            if not protocol.read_queue and not self._transport.is_closing():
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

            # Queued packets are (data, address) pairs; only the data is returned
            payload = packet[0]
            return payload

    async def send(self, item: bytes) -> None:
        """
        Send ``item`` once writing is allowed.

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: if the transport is closing

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
1389

1390

1391
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket operating directly on a raw socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Receive one datagram via ``recvfrom()``, waiting for readability as needed.

        :raises ClosedResourceError: if the socket fails while this object is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    packet = self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    # Nothing to read yet; retry once the socket is readable
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    return packet

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Send one datagram; ``item`` is unpacked into ``sendto()``.

        :raises ClosedResourceError: if the socket fails while this object is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                except BlockingIOError:
                    # Cannot write yet; retry once the socket is writable
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    break
1425

1426

1427
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket operating directly on a raw socket."""

    async def receive(self) -> bytes:
        """
        Receive one datagram via ``recv()``, waiting for readability as needed.

        :raises ClosedResourceError: if the socket fails while this object is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    payload = self._raw_socket.recv(65536)
                except BlockingIOError:
                    # Nothing to read yet; retry once the socket is readable
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    return payload

    async def send(self, item: bytes) -> None:
        """
        Send one datagram via ``send()``, waiting for writability as needed.

        :raises ClosedResourceError: if the socket fails while this object is closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                except BlockingIOError:
                    # Cannot write yet; retry once the socket is writable
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    break
1461

1462

1463
# RunVar named "read_events" holding a dict whose values are asyncio.Event objects.
# NOTE(review): presumably keyed by socket and consumed by the "wait until
# readable" helpers outside this section -- confirm.
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
10✔
1464
# RunVar named "write_events" holding a dict whose values are asyncio.Event objects.
# NOTE(review): presumably keyed by socket and consumed by the "wait until
# writable" helpers outside this section -- confirm.
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
10✔
1465

1466

1467
#
1468
# Synchronization
1469
#
1470

1471

1472
class Event(BaseEvent):
    """Event implementation that delegates to an :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Create the instance directly through object.__new__
        instance = object.__new__(cls)
        return instance

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the underlying event."""
        self._event.set()

    def is_set(self) -> bool:
        """Return whether the underlying event is set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Wait until the event is set, then run a checkpoint."""
        completed = await self._event.wait()
        if completed:
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics built from the number of the event's waiters."""
        waiter_count = len(self._event._waiters)  # type: ignore[attr-defined]
        return EventStatistics(waiter_count)
1491

1492

1493
class CapacityLimiter(BaseCapacityLimiter):
    """
    Limits how many borrowers may hold a token at the same time.

    Borrowers that cannot get a token right away wait in a FIFO queue
    (``_wait_queue``) on their own :class:`asyncio.Event`.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        # Goes through the property setter, which validates the value
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the total number of tokens.

        If the total grows, one queued waiter is woken up per added token.

        :raises TypeError: if the value is neither an int nor ``math.inf``
        :raises ValueError: if the value is less than 1

        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Woken waiters must also be removed from the queue (just like
        # release_on_behalf_of() does). A stale entry would otherwise keep
        # acquire_on_behalf_of_nowait() raising WouldBlock and would absorb a
        # later release() notification meant for a real waiter.
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently borrowed."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens currently not borrowed."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """Acquire a token for the current task without blocking."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower`` without blocking.

        :raises RuntimeError: if the borrower already holds a token
        :raises WouldBlock: if others are already waiting or no tokens are free

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Acquire a token for ``borrower``, waiting in line if necessary."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Leave the queue if the wait was interrupted
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Give the token back if the checkpoint raised
                self.release()
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by ``borrower``.

        :raises RuntimeError: if the borrower is not holding a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's " "tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return a snapshot of borrowed/total tokens, borrowers and queue length."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
1610

1611

1612
# RunVar named "_default_thread_limiter" holding a CapacityLimiter.
# NOTE(review): presumably the default limiter for worker threads, read by code
# outside this section -- confirm.
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
10✔
1613

1614

1615
#
1616
# Operating system signals
1617
#
1618

1619

1620
class _SignalReceiver:
10✔
1621
    def __init__(self, signals: tuple[Signals, ...]):
10✔
1622
        self._signals = signals
8✔
1623
        self._loop = get_running_loop()
8✔
1624
        self._signal_queue: Deque[Signals] = deque()
8✔
1625
        self._future: asyncio.Future = asyncio.Future()
8✔
1626
        self._handled_signals: set[Signals] = set()
8✔
1627

1628
    def _deliver(self, signum: Signals) -> None:
10✔
1629
        self._signal_queue.append(signum)
8✔
1630
        if not self._future.done():
8✔
1631
            self._future.set_result(None)
8✔
1632

1633
    def __enter__(self) -> _SignalReceiver:
10✔
1634
        for sig in set(self._signals):
8✔
1635
            self._loop.add_signal_handler(sig, self._deliver, sig)
8✔
1636
            self._handled_signals.add(sig)
8✔
1637

1638
        return self
8✔
1639

1640
    def __exit__(
10✔
1641
        self,
1642
        exc_type: type[BaseException] | None,
1643
        exc_val: BaseException | None,
1644
        exc_tb: TracebackType | None,
1645
    ) -> bool | None:
1646
        for sig in self._handled_signals:
8✔
1647
            self._loop.remove_signal_handler(sig)
8✔
1648
        return None
8✔
1649

1650
    def __aiter__(self) -> _SignalReceiver:
10✔
1651
        return self
8✔
1652

1653
    async def __anext__(self) -> Signals:
10✔
1654
        await AsyncIOBackend.checkpoint()
8✔
1655
        if not self._signal_queue:
8✔
1656
            self._future = asyncio.Future()
×
1657
            await self._future
×
1658

1659
        return self._signal_queue.popleft()
8✔
1660

1661

1662
#
1663
# Testing and debugging
1664
#
1665

1666

1667
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a ``TaskInfo`` for ``task``.

    The name and parent id come from the recorded task state when there is one.
    Otherwise the parent id is ``None`` and the name is the native task name,
    or ``None`` when native task names are not in use.
    """
    state = _task_states.get(task)
    if state is not None:
        name, parent_id = state.name, state.parent_id
    elif _native_task_names:
        name, parent_id = task.get_name(), None
    else:
        name = parent_id = None

    return TaskInfo(id(task), parent_id, name, get_coro(task))
1677

1678

1679
async def _shutdown_default_executor(loop: asyncio.BaseEventLoop) -> None:
10✔
1680
    """Schedule the shutdown of the default executor.
1681
    BaseEventLoop.shutdown_default_executor was introduced in Python 3.9.
1682
    This function is an adapted version of the method from Python 3.11.
1683
    It's used in TestRunner.close only if python < 3.9.
1684
    """
1685

1686
    def _do_shutdown(
4✔
1687
        loop_: asyncio.BaseEventLoop, future: asyncio.futures.Future
1688
    ) -> None:
1689
        try:
4✔
1690
            loop_._default_executor.shutdown(wait=True)  # type: ignore[attr-defined]
4✔
1691
            loop_.call_soon_threadsafe(future.set_result, None)
4✔
1692
        except Exception as ex:
×
1693
            loop_.call_soon_threadsafe(future.set_exception, ex)
×
1694

1695
    if loop._default_executor is None:  # type: ignore[attr-defined]
4✔
1696
        return
4✔
1697
    future = loop.create_future()
4✔
1698
    thread = threading.Thread(
4✔
1699
        target=_do_shutdown,
1700
        args=(
1701
            loop,
1702
            future,
1703
        ),
1704
    )
1705
    thread.start()
4✔
1706
    try:
4✔
1707
        await future
4✔
1708
    finally:
1709
        thread.join()
4✔
1710

1711

1712
class TestRunner(abc.TestRunner):
    """
    Runs test functions and fixtures on an event loop owned by the runner.

    Coroutines are handed through a memory object stream to one long-lived
    runner task, so fixtures and tests all execute inside that same task.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        debug: bool = False,
        use_uvloop: bool = False,
        policy: asyncio.AbstractEventLoopPolicy | None = None,
    ):
        self._exceptions: list[BaseException] = []
        _maybe_set_event_loop_policy(policy, use_uvloop)
        loop = asyncio.new_event_loop()
        self._loop = loop
        loop.set_debug(debug)
        loop.set_exception_handler(self._exception_handler)
        self._runner_task: asyncio.Task | None = None
        asyncio.set_event_loop(loop)

    def _cancel_all_tasks(self) -> None:
        """Cancel every task on the loop and re-raise the first task failure."""
        pending = all_tasks(self._loop)
        if not pending:
            return

        for task in pending:
            task.cancel()

        # Let the cancellations run their course before inspecting the results
        self._loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))

        for task in pending:
            if not task.cancelled() and task.exception() is not None:
                raise cast(BaseException, task.exception())

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances; defer anything else to the loop."""
        exc = context.get("exception")
        if isinstance(exc, Exception):
            self._exceptions.append(context["exception"])
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        """Re-raise any exceptions raised in asynchronous callbacks."""
        if not self._exceptions:
            return

        # Swap in a fresh list so the same failures are never reported twice
        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Coroutine[Any, Any, T_Retval], Future[T_Retval]]
        ],
    ) -> None:
        """Await each received coroutine and store its outcome in its future."""
        with receive_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except BaseException as exc:
                    if not future.cancelled():
                        future.set_exception(exc)
                else:
                    if not future.cancelled():
                        future.set_result(retval)

    async def _call_in_runner_task(
        self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object
    ) -> T_Retval:
        """Run ``func(*args, **kwargs)`` in the runner task and return its result."""
        if self._runner_task is None:
            # First call: create the stream and the task that consumes it
            self._send_stream, receive_stream = create_memory_object_stream(1)
            self._runner_task = self._loop.create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        result: asyncio.Future[T_Retval] = self._loop.create_future()
        self._send_stream.send_nowait((coro, result))
        return await result

    def close(self) -> None:
        """Stop the runner task, cancel leftover tasks and close the loop."""
        loop = self._loop
        try:
            if self._runner_task is not None:
                self._runner_task = None
                loop.run_until_complete(self._send_stream.aclose())
                del self._send_stream

            self._cancel_all_tasks()
            loop.run_until_complete(loop.shutdown_asyncgens())
            if hasattr(loop, "shutdown_default_executor"):
                # asyncio in Python >= 3.9 or uvloop >= 0.15.0
                loop.run_until_complete(loop.shutdown_default_executor())
            elif isinstance(loop, asyncio.BaseEventLoop) and hasattr(
                loop, "_default_executor"
            ):
                # asyncio in Python < 3.9
                loop.run_until_complete(_shutdown_default_executor(loop))
        finally:
            asyncio.set_event_loop(None)
            loop.close()

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Drive an async generator fixture: yield its value, then finalize it.

        :raises RuntimeError: if the generator yields a second time

        """
        asyncgen = fixture_func(**kwargs)
        setup = self._call_in_runner_task(asyncgen.asend, None)
        fixturevalue: T_Retval = self._loop.run_until_complete(setup)
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            teardown = self._call_in_runner_task(asyncgen.asend, None)
            self._loop.run_until_complete(teardown)
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self._loop.run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        call = self._call_in_runner_task(fixture_func, **kwargs)
        retval = self._loop.run_until_complete(call)
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a test coroutine in the runner task, then report all failures."""
        try:
            call = self._call_in_runner_task(test_func, **kwargs)
            self._loop.run_until_complete(call)
        except Exception as exc:
            # Reported together with any failures from asynchronous callbacks
            self._exceptions.append(exc)

        self._raise_async_exceptions()
1863

1864

1865
class AsyncIOBackend(AsyncBackend):
10✔
1866
    @classmethod
    def run(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple,
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func`` in a new event loop and return its result.

        :param func: the coroutine function to run
        :param args: positional arguments passed to ``func``
        :param kwargs: keyword arguments passed to ``func``
        :param options: backend options; ``debug``, ``policy`` and ``use_uvloop``
            are read from it
        :return: whatever ``func`` returns

        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task_state = TaskState(None, get_callable_name(func), None)
            _task_states[task] = task_state
            if _native_task_names:
                task.set_name(task_state.name)

            try:
                # Forward the keyword arguments too; they used to be dropped
                return await func(*args, **kwargs)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        policy = options.get("policy", None)
        use_uvloop = options.get("use_uvloop", False)
        _maybe_set_event_loop_policy(policy, use_uvloop)
        return native_run(wrapper(), debug=debug)
1892

1893
    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop, which serves as the backend token."""
        running_loop = get_running_loop()
        return running_loop
1896

1897
    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's current clock value."""
        running_loop = get_running_loop()
        return running_loop.time()
1900

1901
    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class this backend uses for cancellation."""
        exc_class: type[BaseException] = CancelledError
        return exc_class
1904

1905
    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once via a zero-length sleep."""
        zero_delay = 0
        await sleep(zero_delay)
1908

1909
    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield to the event loop only if an effective cancel scope was cancelled.

        The scope chain of the current task is walked outwards and the walk
        stops at the first shielded scope. Nothing happens when there is no
        current task or no state is recorded for it.
        """
        task = current_task()
        if task is None:
            return

        try:
            state = _task_states[task]
        except KeyError:
            return

        scope = state.cancel_scope
        while scope:
            if scope.cancel_called:
                # Give the pending cancellation a chance to be delivered
                await sleep(0)
            elif scope.shield:
                break
            else:
                scope = scope._parent_scope
1927

1928
    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield to the event loop once from inside a shielded cancel scope."""
        shielded_scope = CancelScope(shield=True)
        with shielded_scope:
            await sleep(0)
1932

1933
    @classmethod
    async def sleep(cls, delay: float) -> None:
        """
        Pause the current task by awaiting the module-level ``sleep``.

        :param delay: the value passed straight through to ``sleep``

        """
        # Inside the method body the bare name resolves to the module-level
        # ``sleep``, not to this classmethod
        pause = sleep(delay)
        await pause
1936

1937
    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """
        Create a new cancel scope.

        :param deadline: forwarded to ``CancelScope`` (infinite by default)
        :param shield: forwarded to ``CancelScope``

        """
        scope = CancelScope(deadline=deadline, shield=shield)
        return scope
1942

1943
    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline among the current task's effective scopes.

        :return: ``-inf`` if an effective scope has been cancelled, ``inf`` when
            no task state is recorded or no deadline applies, otherwise the
            smallest deadline found before the first shielded scope (inclusive)

        """
        try:
            state = _task_states[current_task()]  # type: ignore[index]
        except KeyError:
            return math.inf

        scope = state.cancel_scope
        nearest = math.inf
        while scope:
            nearest = min(nearest, scope.deadline)
            if scope._cancel_called:
                # Already cancelled: the deadline has effectively passed
                return -math.inf

            if scope.shield:
                # Scopes outside a shielded one do not affect this task
                break

            scope = scope._parent_scope

        return nearest
1964

1965
    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group of this backend's ``TaskGroup`` type."""
        task_group = TaskGroup()
        return task_group
1968

1969
    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event of this backend's ``Event`` type."""
        event = Event()
        return event
1972

1973
    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """
        Create a capacity limiter.

        :param total_tokens: passed on to the ``CapacityLimiter`` constructor

        """
        limiter = CapacityLimiter(total_tokens)
        return limiter
1976

1977
    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[..., T_Retval],
        args: tuple[Any, ...],
        cancellable: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Call ``func(*args)`` in a worker thread and return its result.

        :param func: the callable to run in the worker thread
        :param args: positional arguments for ``func``
        :param cancellable: when false, the wait for the result runs inside a
            shielded cancel scope
        :param limiter: limiter to hold while the call is in flight; the default
            thread limiter is used when omitted
        :return: the value the worker delivers through the result future

        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        effective_limiter = limiter or cls.current_default_thread_limiter()
        async with effective_limiter:
            with CancelScope(shield=not cancellable):
                result_future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if idle_workers:
                    # Reuse the worker that most recently became idle
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers and not (
                        now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME
                    ):
                        stale_worker = idle_workers.popleft()
                        stale_worker.root_task.remove_done_callback(stale_worker.stop)
                        stale_worker.stop()
                else:
                    # No idle worker: start one, to be stopped with the root task
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)

                # Run in a context copy where the current async library is unset
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                worker.queue.put_nowait((context, func, args, result_future))
                return await result_future
2030

2031
    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[..., Awaitable[T_Retval]],
        args: tuple[Any, ...],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` on the event loop given by ``token`` and wait for it.

        :param func: the coroutine function to call
        :param args: positional arguments for ``func``
        :param token: the event loop to submit the coroutine to
        :return: the coroutine's result (blocks the calling thread until ready)

        """
        loop = cast(AbstractEventLoop, token)
        coro = func(*args)
        pending: concurrent.futures.Future[T_Retval] = (
            asyncio.run_coroutine_threadsafe(coro, loop)
        )
        return pending.result()
2043

2044
    @classmethod
    def run_sync_from_thread(
        cls, func: Callable[..., T_Retval], args: tuple[Any, ...], token: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` on the event loop given by ``token`` and wait for it.

        :param func: the callable to run via ``call_soon_threadsafe``
        :param args: positional arguments for ``func``
        :param token: the event loop to schedule the call on
        :return: the return value of ``func`` (blocks the calling thread)

        """
        outcome: concurrent.futures.Future[T_Retval] = Future()

        @wraps(func)
        def wrapper() -> None:
            try:
                outcome.set_result(func(*args))
            except BaseException as exc:
                outcome.set_exception(exc)
                # Only plain Exceptions are absorbed here; anything else is
                # also re-raised into the event loop
                if not isinstance(exc, Exception):
                    raise

        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return outcome.result()
2061

2062
    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal of this backend's ``BlockingPortal`` type."""
        portal = BlockingPortal()
        return portal
2065

2066
    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """
        Start a subprocess and wrap it, with its pipes, in a ``Process``.

        :param command: a shell command line when ``shell`` is true, otherwise
            the sequence of program arguments
        :param shell: selects ``create_subprocess_shell`` over
            ``create_subprocess_exec``
        :param stdin: forwarded to the asyncio subprocess call
        :param stdout: forwarded to the asyncio subprocess call
        :param stderr: forwarded to the asyncio subprocess call
        :param cwd: forwarded to the asyncio subprocess call
        :param env: forwarded to the asyncio subprocess call
        :param start_new_session: forwarded to the asyncio subprocess call

        """
        await cls.checkpoint()

        # Both spawn variants take exactly the same keyword arguments
        spawn_options: dict[str, Any] = {
            "stdin": stdin,
            "stdout": stdout,
            "stderr": stderr,
            "cwd": cwd,
            "env": env,
            "start_new_session": start_new_session,
        }
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command), **spawn_options
            )
        else:
            process = await asyncio.create_subprocess_exec(*command, **spawn_options)

        # Only wrap the pipes that were actually opened
        stdin_stream = None
        if process.stdin:
            stdin_stream = StreamWriterWrapper(process.stdin)

        stdout_stream = None
        if process.stdout:
            stdout_stream = StreamReaderWrapper(process.stdout)

        stderr_stream = None
        if process.stderr:
            stderr_stream = StreamReaderWrapper(process.stderr)

        return Process(process, stdin_stream, stdout_stream, stderr_stream)
2105

2106
    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Arrange for the worker processes in ``workers`` to be shut down.

        A shutdown task is started and a forcible shutdown callback is attached
        to the root task.
        """
        # The task is only given a name when native task names are in use
        task_kwargs = {}
        if _native_task_names:
            task_kwargs = {"name": "AnyIO process pool shutdown task"}

        create_task(_shutdown_process_pool_on_exit(workers), **task_kwargs)
        force_shutdown = partial(_forcibly_shutdown_process_pool_on_exit, workers)
        find_root_task().add_done_callback(force_shutdown)
2115

2116
    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """
        Open a TCP connection and wrap it in a ``SocketStream``.

        :param host: host to connect to
        :param port: port to connect to
        :param local_address: passed to the event loop as ``local_addr``

        """
        connection = await get_running_loop().create_connection(
            StreamProtocol, host, port, local_addr=local_address
        )
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol], connection
        )
        # Reading is paused on the transport before it is handed to the stream
        transport.pause_reading()
        return SocketStream(transport, protocol)
2128

2129
    @classmethod
    async def connect_unix(cls, path: str) -> abc.UNIXSocketStream:
        """
        Connect to the UNIX socket at ``path`` and wrap it in a stream.

        :param path: file system path of the socket to connect to
        :return: a ``UNIXSocketStream`` around the connected socket
        :raises BaseException: any failure, including cancellation while
            waiting for the socket to become writable, is re-raised after the
            socket has been closed

        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        # The outer handler also covers the wait below, which a handler placed
        # next to ``except BlockingIOError`` would not, so the socket no longer
        # leaks when the task is cancelled while waiting
        try:
            raw_socket.setblocking(False)
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    # Connection still in progress: wait until the socket is
                    # writable, then retry connect()
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
2148

2149
    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a ``TCPSocketListener``."""
        listener = TCPSocketListener(sock)
        return listener
2152

2153
    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap ``sock`` in a ``UNIXSocketListener``."""
        listener = UNIXSocketListener(sock)
        return listener
2156

2157
    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a datagram endpoint and wrap it in a UDP socket object.

        :param family: address family for the endpoint
        :param local_address: passed to the event loop as ``local_addr``
        :param remote_address: passed to the event loop as ``remote_addr``; when
            set, a ``ConnectedUDPSocket`` is returned instead of a ``UDPSocket``
        :param reuse_port: passed to the event loop as ``reuse_port``
        :raises Exception: the exception recorded on the protocol, if any, after
            the transport has been closed

        """
        loop = get_running_loop()
        transport, protocol = await loop.create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if remote_address:
            return ConnectedUDPSocket(transport, protocol)

        return UDPSocket(transport, protocol)
2180

2181
    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap ``raw_socket`` in a UNIX datagram socket, connecting it if asked to.

        :param raw_socket: the socket to wrap
        :param remote_path: when set, the socket is connected to this path first
            and a ``ConnectedUNIXDatagramSocket`` is returned
        :raises BaseException: any failure while connecting, including
            cancellation while waiting for writability, is re-raised after
            ``raw_socket`` has been closed

        """
        await cls.checkpoint()
        loop = get_running_loop()

        if not remote_path:
            return UNIXDatagramSocket(raw_socket)

        # The outer handler also covers the wait below, which a handler placed
        # next to ``except BlockingIOError`` would not, so the socket is closed
        # on cancellation just like on any other connect failure
        try:
            while True:
                try:
                    raw_socket.connect(remote_path)
                except BlockingIOError:
                    # Connection still in progress: wait until the socket is
                    # writable, then retry connect()
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                    await f
                else:
                    return ConnectedUNIXDatagramSocket(raw_socket)
        except BaseException:
            raw_socket.close()
            raise
2204

2205
    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host``/``port`` through the running loop's ``getaddrinfo()``."""
        loop = get_running_loop()
        results = await loop.getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )
        return results
2227

2228
    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Look up ``sockaddr`` through the running loop's ``getnameinfo()``."""
        loop = get_running_loop()
        name_info = await loop.getnameinfo(sockaddr, flags)
        return name_info
2233

2234
    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` is reported readable by the event loop.

        :raises BusyResourceError: if a read wait is already registered for
            ``sock``
        :raises ClosedResourceError: if the registration for ``sock`` was gone
            by the time the wait finished

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = asyncio.Event()
        read_events[sock] = event
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            # The entry is removed by whoever gets to it first; if it is
            # already gone, the reader is not ours to remove any more
            still_registered = read_events.pop(sock, None) is not None
            if still_registered:
                loop.remove_reader(sock)

        if not still_registered:
            raise ClosedResourceError
2260

2261
    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until ``sock`` is reported writable by the event loop.

        :raises BusyResourceError: if a write wait is already registered for
            ``sock``
        :raises ClosedResourceError: if the registration for ``sock`` was gone
            by the time the wait finished

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = asyncio.Event()
        write_events[sock] = event
        # NOTE(review): the writer is added by file descriptor but removed by
        # socket object below, unlike wait_socket_readable() - confirm intended
        loop.add_writer(sock.fileno(), event.set)
        try:
            await event.wait()
        finally:
            # The entry is removed by whoever gets to it first; if it is
            # already gone, the writer is not ours to remove any more
            still_registered = write_events.pop(sock, None) is not None
            if still_registered:
                loop.remove_writer(sock)

        if not still_registered:
            raise ClosedResourceError
2287

2288
    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """
        Return the default worker thread limiter, creating it on first use.

        A newly created limiter has 40 tokens.
        """
        try:
            limiter = _default_thread_limiter.get()
        except LookupError:
            # Not set yet: create the limiter and remember it
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)

        return limiter
2296

2297
    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """
        Create a receiver for the given signals.

        :param signals: the signals to listen for
        :return: a ``_SignalReceiver`` to be used as a context manager

        """
        receiver = _SignalReceiver(signals)
        return receiver
2302

2303
    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return a ``TaskInfo`` describing the currently running task."""
        task = current_task()
        return _create_task_info(task)  # type: ignore[arg-type]
2306

2307
    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        """Return a ``TaskInfo`` for every task that has not finished yet."""
        unfinished = (task for task in all_tasks() if not task.done())
        return [_create_task_info(task) for task in unfinished]
2310

2311
    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every task except the calling one is waiting on something.

        The tasks are polled: whenever one is found without a pending waiter,
        this sleeps for 0.1 seconds and starts the check over.
        """
        await cls.checkpoint()
        me = current_task()
        while True:
            for task in all_tasks():
                if task is me:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    # This task is not parked on a pending future yet
                    break
            else:
                return

            await sleep(0.1)
2326

2327
    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """
        Create a test runner.

        :param options: keyword arguments for the ``TestRunner`` constructor

        """
        runner = TestRunner(**options)
        return runner
2330

2331

2332
# Module-level alias for the backend implementation class
# (presumably the name other modules look up on this module - confirm)
backend_class = AsyncIOBackend
10✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc